diff --git a/Wavefunctions/CubicSplineSolver.ipynb b/Wavefunctions/CubicSplineSolver.ipynb new file mode 100644 index 0000000..7b01ff5 --- /dev/null +++ b/Wavefunctions/CubicSplineSolver.ipynb @@ -0,0 +1,1856 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sympy import *\n", + "from IPython.display import display, Latex, HTML, Markdown\n", + "init_printing()\n", + "from eqn_manip import *\n", + "from codegen_extras import *\n", + "import codegen_extras\n", + "from importlib import reload\n", + "from sympy.codegen.ast import Assignment, For, CodeBlock, real, Variable, Pointer, Declaration\n", + "from sympy.codegen.cnodes import void" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cubic Spline solver - derivation and code generation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tridiagonal Solver\n", + "From Wikipedia: https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm\n", + "\n", + "In the future it would be good to derive these equations from Gaussian elimintation (as on the Wikipedia page), but for now they are simply given." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "n = Symbol('n', integer=True)\n", + "i = Symbol('i', integer=True)\n", + "x = IndexedBase('x',shape=(n,))\n", + "dp = IndexedBase(\"d'\",shape=(n,))\n", + "cp = IndexedBase(\"c'\",shape=(n,))\n", + "a = IndexedBase(\"a\",shape=(n,))\n", + "b = IndexedBase(\"b\",shape=(n,))\n", + "c = IndexedBase(\"c\",shape=(n,))\n", + "d = IndexedBase(\"d\",shape=(n,))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${c'}_{0} = \\frac{{c}_{0}}{{b}_{0}}$$" + ], + "text/plain": [ + " c[0]\n", + "c'[0] = ────\n", + " b[0]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${d'}_{0} = \\frac{{d}_{0}}{{b}_{0}}$$" + ], + "text/plain": [ + " d[0]\n", + "d'[0] = ────\n", + " b[0]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${d'}_{i} = \\frac{- {a}_{i} {d'}_{i - 1} + {d}_{i}}{- {a}_{i} {c'}_{i - 1} + {b}_{i}}$$" + ], + "text/plain": [ + " -a[i]⋅d'[i - 1] + d[i]\n", + "d'[i] = ──────────────────────\n", + " -a[i]⋅c'[i - 1] + b[i]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${c'}_{i} = \\frac{{c}_{i}}{- {a}_{i} {c'}_{i - 1} + {b}_{i}}$$" + ], + "text/plain": [ + " c[i] \n", + "c'[i] = ──────────────────────\n", + " -a[i]⋅c'[i - 1] + b[i]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# forward sweep\n", + "# start/end using the natural range for math notation\n", + "#start = 1\n", + "#end = n\n", + "# Use the C++ range 0,n-1\n", + "start = 0\n", + "end = n-1\n", + "teq1 = Eq(cp[start], c[start]/b[start])\n", + "display(teq1)\n", + "teq2 = Eq(dp[start], d[start]/b[start])\n", + "display(teq2)\n", + "teq3 = Eq(dp[i],(d[i] - dp[i-1]*a[i])/ (b[i] - cp[i-1]*a[i]))\n", + "display(teq3)\n", + "teq4 = Eq(cp[i],c[i]/(b[i] - cp[i-1]*a[i]))\n", + "display(teq4)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${x}_{n - 1} = {d'}_{n - 1}$$" + ], + "text/plain": [ + "x[n - 1] = d'[n - 1]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${x}_{i} = - {c'}_{i} {x}_{i + 1} + {d'}_{i}$$" + ], + "text/plain": [ + "x[i] = -c'[i]⋅x[i + 1] + d'[i]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# backward sweep\n", + "teq5 = Eq(x[end],dp[end])\n", + "display(teq5)\n", + "teq6 = Eq(x[i],dp[i] - cp[i]*x[i+1])\n", + "display(teq6)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cubic Spline equations\n", + "Start with uniform knot spacing. The derivation is easier to see than in the case with general knot spacing." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$y = t^{3} d + t^{2} c + t b + a$$" + ], + "text/plain": [ + " 3 2 \n", + "y = t ⋅d + t ⋅c + t⋅b + a" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${y}_{i} = t^{3} {d}_{i} + t^{2} {c}_{i} + t {b}_{i} + {a}_{i}$$" + ], + "text/plain": [ + " 3 2 \n", + "y[i] = t ⋅d[i] + t ⋅c[i] + t⋅b[i] + a[i]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Distance from the previous knot, for the case of uniform knot spacing\n", + "t = Symbol('t')\n", + "\n", + "# Number of knots\n", + "n = Symbol('n', integer=True)\n", + "i = Symbol('i', integer=True)\n", + "# Function values to intepolated at the knots\n", + "y = IndexedBase('y',shape=(n,))\n", + "\n", + "# Coefficients of the spline function\n", + "a,b,c,d = [IndexedBase(s, shape=(n,)) for s in 'a b c d'.split()]\n", + "\n", + "# Cubic spline function\n", + "s = a + b*t + c*t*t + d*t**3\n", + "display(Eq(y,s))\n", + "\n", + "# With indexed variables\n", + "si = a[i] + b[i]*t + c[i]*t*t + d[i]*t**3\n", + "display(Eq(y[i],si))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Strategy\n", + "To eventually reduce the equations to a tridiagonal form, express the equations in terms of the second derivative ($E$).\n", + "See the MathWorld page for cubic splines, which derives the equations in terms of the first derivative ($D$).\n", + "\n", + "http://mathworld.wolfram.com/CubicSpline.html" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${a}_{i} = {y}_{i}$$" + ], + "text/plain": [ + "a[i] = y[i]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Value at knots (t=0)\n", + "sp1 = Eq(si.subs(t,0), y[i])\n", + "sp1" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${a}_{i} + {b}_{i} + {c}_{i} + {d}_{i} = {y}_{i + 1}$$" + ], + "text/plain": [ + "a[i] + b[i] + c[i] + d[i] = y[i + 1]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Value at knots (t=1)\n", + "sp2 = Eq(si.subs(t,1), y[i+1])\n", + "sp2" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${E}_{i} = 2 {c}_{i}$$" + ], + "text/plain": [ + "E[i] = 2⋅c[i]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Express the second derivative at the beginning of the interval in terms of E\n", + "E = IndexedBase('E',shape=(n,))\n", + "sp3 = Eq(E[i], diff(si,t,2).subs(t,0))\n", + "sp3" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${E}_{i + 1} = 2 {c}_{i} + 6 {d}_{i}$$" + ], + "text/plain": [ + "E[i + 1] = 2⋅c[i] + 6⋅d[i]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Express the second derivative at the end of the interval in terms of E\n", + "sp4 = Eq(E[i+1], diff(si,t,2).subs(t,1))\n", + "sp4" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${b}_{i} + 2 {c}_{i} + 3 {d}_{i} = {b}_{i + 1}$$" + ], + "text/plain": [ + "b[i] + 2⋅c[i] + 3⋅d[i] = b[i + 1]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Continuity of the first derivative\n", + "sp5 = Eq(diff(si,t).subs(t,1), diff(si,t).subs(t,0).subs(i,i+1))\n", + "sp5" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### For general spacing of the knots" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "L = IndexedBase('L',shape=(n,)) # L[i] = x[i+1] - x[i]\n", + "t = Symbol('t')\n", + "x = IndexedBase('x',shape=(n,))\n", + "\n", + "si = a[i] + b[i]*t + c[i]*t*t + d[i]*t**3" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${a}_{i} = {y}_{i}$$" + ], + "text/plain": [ + "a[i] = y[i]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Value at knots (t=0)\n", + "sp1 = Eq(si.subs(t,0), y[i])\n", + "sp1" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${L}_{i}^{3} {d}_{i} + {L}_{i}^{2} {c}_{i} + {L}_{i} {b}_{i} + {a}_{i} = {y}_{i + 1}$$" + ], + "text/plain": [ + " 3 2 \n", + "L[i] ⋅d[i] + L[i] ⋅c[i] + L[i]⋅b[i] + a[i] = y[i + 1]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Value at next knot\n", + "sp2 = Eq(si.subs(t,L[i]), y[i+1])\n", + "sp2" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${E}_{i} = 2 {c}_{i}$$" + ], + "text/plain": [ + "E[i] = 2⋅c[i]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Express the second derivative at the beginning of the interval in terms of E\n", + "E = IndexedBase('E',shape=(n,))\n", + "sp3 = Eq(E[i], diff(si,t,2).subs(t,0))\n", + "sp3" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${E}_{i + 1} = 6 {L}_{i} {d}_{i} + 2 {c}_{i}$$" + ], + "text/plain": [ + "E[i + 1] = 6⋅L[i]⋅d[i] + 2⋅c[i]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Express the second derivative at the end of the interval in terms of E\n", + "sp4 = Eq(E[i+1], diff(si,t,2).subs(t,L[i]))\n", + "sp4" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\left \\{ {a}_{i} : {y}_{i}, \\quad {b}_{i} : \\frac{- \\frac{\\left({E}_{i + 1} + 2 {E}_{i}\\right) {L}_{i}^{2}}{6} + {y}_{i + 1} - {y}_{i}}{{L}_{i}}, \\quad {c}_{i} : \\frac{{E}_{i}}{2}, \\quad {d}_{i} : \\frac{{E}_{i + 1} - {E}_{i}}{6 {L}_{i}}\\right \\}$$" + ], + "text/plain": [ + "⎧ 2 \n", + "⎪ (E[i + 1] + 2⋅E[i])⋅L[i] \n", + "⎪ - ───────────────────────── + y[i + 1] - y[i] \n", + "⎨ 6 E[i] \n", + "⎪a[i]: y[i], b[i]: ─────────────────────────────────────────────, c[i]: ────, \n", + "⎪ L[i] 2 \n", + "⎩ \n", + "\n", + " ⎫\n", + " ⎪\n", + " ⎪\n", + " E[i + 1] - E[i]⎬\n", + "d[i]: ───────────────⎪\n", + " 6⋅L[i] ⎪\n", + " ⎭" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + " # Solve for spline coefficients in terms of E's\n", + "sln = solve([sp1,sp2,sp3,sp4], [a[i],b[i],c[i],d[i]])\n", + "sln" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\left \\{ {a}_{i + 1} : {y}_{i + 1}, \\quad {b}_{i + 1} : \\frac{- \\frac{\\left(2 {E}_{i + 1} + {E}_{i + 2}\\right) {L}_{i + 1}^{2}}{6} - {y}_{i + 1} + {y}_{i + 2}}{{L}_{i + 1}}, \\quad {c}_{i + 1} : \\frac{{E}_{i + 1}}{2}, \\quad {d}_{i + 1} : \\frac{- {E}_{i + 1} + {E}_{i + 2}}{6 {L}_{i + 1}}\\right \\}$$" + ], + "text/plain": [ + "⎧ 2 \n", + "⎪ (2⋅E[i + 1] + E[i + 2])⋅L[i + 1] \n", + "⎪ - ───────────────────────────────── - y[i + 1] \n", + "⎨ 6 \n", + "⎪a[i + 1]: y[i + 1], b[i + 1]: ───────────────────────────────────────────────\n", + "⎪ L[i + 1] \n", + "⎩ \n", + "\n", + " ⎫\n", + " ⎪\n", + "+ y[i + 2] ⎪\n", + " E[i + 1] -E[i + 1] + E[i + 2]⎬\n", + "──────────, c[i + 1]: ────────, d[i + 1]: ────────────────────⎪\n", + " 2 6⋅L[i + 1] ⎪\n", + " ⎭" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# also for i+1\n", + "sln1 = {k.subs(i,i+1):v.subs(i,i+1) for k,v in sln.items()}\n", + "sln1" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$3 {L}_{i}^{2} {d}_{i} + 2 {L}_{i} {c}_{i} + {b}_{i} = {b}_{i + 1}$$" + ], + "text/plain": [ + " 2 \n", + "3⋅L[i] ⋅d[i] + 2⋅L[i]⋅c[i] + b[i] = b[i + 1]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Continuity of first derivatives at knots\n", + "# This will define the tridiagonal system to be solved\n", + "sp5 = Eq(diff(si,t).subs(t,L[i]), diff(si,t).subs(i, i+1).subs(t,0))\n", + "sp5" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\frac{{E}_{i + 1} {L}_{i}}{3} + \\frac{{E}_{i} {L}_{i}}{6} + \\frac{{y}_{i + 1}}{{L}_{i}} - \\frac{{y}_{i}}{{L}_{i}} = - \\frac{{E}_{i + 1} {L}_{i + 1}}{3} - \\frac{{E}_{i + 2} {L}_{i + 1}}{6} - \\frac{{y}_{i + 1}}{{L}_{i + 1}} + \\frac{{y}_{i + 2}}{{L}_{i + 1}}$$" + ], + "text/plain": [ + "E[i + 1]⋅L[i] E[i]⋅L[i] y[i + 1] y[i] E[i + 1]⋅L[i + 1] E[i + 2]⋅L\n", + "───────────── + ───────── + ──────── - ──── = - ───────────────── - ──────────\n", + " 3 6 L[i] L[i] 3 6 \n", + "\n", + "[i + 1] y[i + 1] y[i + 2]\n", + "─────── - ──────── + ────────\n", + " L[i + 1] L[i + 1]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sp6 = sp5.subs(sln).subs(sln1)\n", + "sp7 = expand(sp6)\n", + "sp7" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\frac{{E}_{i + 1} {L}_{i + 1}}{3} + \\frac{{E}_{i + 1} {L}_{i}}{3} + \\frac{{E}_{i + 2} {L}_{i + 1}}{6} + \\frac{{E}_{i} {L}_{i}}{6} = - \\frac{{y}_{i + 1}}{{L}_{i}} + \\frac{{y}_{i}}{{L}_{i}} - \\frac{{y}_{i + 1}}{{L}_{i + 1}} + \\frac{{y}_{i + 2}}{{L}_{i + 1}}$$" + ], + "text/plain": [ + "E[i + 1]⋅L[i + 1] E[i + 1]⋅L[i] E[i + 2]⋅L[i + 1] E[i]⋅L[i] y[i + 1]\n", + "───────────────── + ───────────── + ───────────────── + ───────── = - ────────\n", + " 3 3 6 6 L[i] \n", + "\n", + " y[i] y[i + 1] y[i + 2]\n", + " + ──── - ──────── + ────────\n", + " L[i] L[i + 1] L[i + 1]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$2 {E}_{i + 1} {L}_{i + 1} + 2 {E}_{i + 1} {L}_{i} + {E}_{i + 2} {L}_{i + 1} + {E}_{i} {L}_{i} = - \\frac{6 {y}_{i + 1}}{{L}_{i}} + \\frac{6 {y}_{i}}{{L}_{i}} - \\frac{6 {y}_{i + 1}}{{L}_{i + 1}} + \\frac{6 {y}_{i + 2}}{{L}_{i + 1}}$$" + ], + "text/plain": [ + " 6⋅y[\n", + "2⋅E[i + 1]⋅L[i + 1] + 2⋅E[i + 1]⋅L[i] + E[i + 2]⋅L[i + 1] + E[i]⋅L[i] = - ────\n", + " L\n", + "\n", + "i + 1] 6⋅y[i] 6⋅y[i + 1] 6⋅y[i + 2]\n", + "────── + ────── - ────────── + ──────────\n", + "[i] L[i] L[i + 1] L[i + 1] " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sp8 = divide_terms(sp7, [E[i],E[i+1],E[i+2]], [y[i],y[i+1],y[i+2]])\n", + "display(sp8)\n", + "sp9 = mult_eqn(sp8,6)\n", + "display(sp9)\n", + "\n", + "# The index 'i' used in the cubic spline equations is not the same 'i' used\n", + "# in the tridigonal solver. Here we need to make them match.\n", + "# The first foundary condition will the equation at index at 0.\n", + "# Adjust the indexing on this equation so i=1 is the index of the first continuity interval match\n", + "sp9 = sp9.subs(i,i-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${L}_{i - 1}$$" + ], + "text/plain": [ + "L[i - 1]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$2 {L}_{i - 1} + 2 {L}_{i}$$" + ], + "text/plain": [ + "2⋅L[i - 1] + 2⋅L[i]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${L}_{i}$$" + ], + "text/plain": [ + "L[i]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Extract the three coefficients in each row for the general case\n", + "symlist = [E[i-1],E[i],E[i+1],E[i+2]]\n", + "coeff1 = get_coeff_for(sp9.lhs, E[i-1], symlist)\n", + "display(coeff1)\n", + "coeff2 = get_coeff_for(sp9.lhs, E[i], symlist)\n", + "display(coeff2)\n", + "coeff3 = get_coeff_for(sp9.lhs, E[i+1], symlist)\n", + "display(coeff3)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${E}_{0} = 0$$" + ], + "text/plain": [ + "E[0] = 0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${E}_{n - 1} = 0$$" + ], + "text/plain": [ + "E[n - 1] = 0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$1$$" + ], + "text/plain": [ + "1" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$0$$" + ], + "text/plain": [ + "0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$0$$" + ], + "text/plain": [ + "0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$1$$" + ], + "text/plain": [ + "1" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now get the coefficients for the boundary conditions (first row and last row)\n", + "\n", + "# Natural BC\n", + "bc_natural_start = Eq(E[i].subs(i,0),0)\n", + "display(bc_natural_start)\n", + "bc_natural_end = Eq(E[i].subs(i,end),0)\n", + "display(bc_natural_end)\n", + "\n", + "# The coefficients and RHS for this BC are pretty simple. but we will follow\n", + "# a deterministic path for derivation anyway.\n", + "bc_natural_start_coeff1 = get_coeff_for(bc_natural_start.lhs, E[start],[E[start]])\n", + "display(bc_natural_start_coeff1)\n", + "bc_natural_start_coeff2 = get_coeff_for(bc_natural_start.lhs, E[start+1],[E[start],E[start+1]])\n", + "display(bc_natural_start_coeff2)\n", + "bc_natural_end_coeff1 = get_coeff_for(bc_natural_end.lhs, E[end-1],[E[end]])\n", + "display(bc_natural_end_coeff1)\n", + "bc_natural_end_coeff2 = get_coeff_for(bc_natural_end.lhs, E[end],[E[end]])\n", + "bc_natural_end_coeff2" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\frac{- \\frac{\\left(2 {E}_{0} + {E}_{1}\\right) {L}_{0}^{2}}{6} - {y}_{0} + {y}_{1}}{{L}_{0}} = yp_{0}$$" + ], + "text/plain": [ + " 2 \n", + " (2⋅E[0] + E[1])⋅L[0] \n", + "- ───────────────────── - y[0] + y[1] \n", + " 6 \n", + "───────────────────────────────────── = yp₀\n", + " L[0] " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$- 2 {E}_{0} {L}_{0} - {E}_{1} {L}_{0} = 6 yp_{0} + \\frac{6 {y}_{0}}{{L}_{0}} - \\frac{6 {y}_{1}}{{L}_{0}}$$" + ], + "text/plain": [ + " 6⋅y[0] 6⋅y[1]\n", + "-2⋅E[0]⋅L[0] - E[1]⋅L[0] = 6⋅yp₀ + ────── - ──────\n", + " L[0] L[0] " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$- 2 {L}_{0}$$" + ], + "text/plain": [ + "-2⋅L[0]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$- {L}_{0}$$" + ], + "text/plain": [ + "-L[0]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# BC - first derivative specified at the beginning of the range\n", + "yp0 = Symbol('yp0')\n", + "eqbc1=Eq(diff(si,t).subs(t,0).subs(sln).subs(i,0), yp0)\n", + "display(eqbc1)\n", + "eqbc1b = divide_terms(expand(eqbc1),[E[0],E[1]],[y[0],y[1],yp0])\n", + "eqbc1c = mult_eqn(eqbc1b, 6)\n", + "display(eqbc1c)\n", + "bc_firstd_start_coeff1 = get_coeff_for(eqbc1c.lhs, E[0], [E[0],E[1]])\n", + "display(bc_firstd_start_coeff1)\n", + "bc_firstd_start_coeff2 = get_coeff_for(eqbc1c.lhs, E[1], [E[0],E[1]])\n", + "display(bc_firstd_start_coeff2)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\begin{cases} - 2 {L}_{0} & \\text{for}\\: yp_{0} < 9.9 \\cdot 10^{29} \\\\1 & \\text{otherwise} \\end{cases}$$" + ], + "text/plain": [ + "⎧-2⋅L[0] for yp₀ < 9.9e+29\n", + "⎨ \n", + "⎩ 1 otherwise " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$\\begin{cases} - {L}_{0} & \\text{for}\\: yp_{0} < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}$$" + ], + "text/plain": [ + "⎧-L[0] for yp₀ < 9.9e+29\n", + "⎨ \n", + "⎩ 0 otherwise " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# For the general algorithm, the input parameters for the boundary conditions are\n", + "# - first derivative, if value is less than cutoff\n", + "# - second derivative is zero, if vlaue is greater than cutoff\n", + "\n", + "bc_cutoff = 0.99e30\n", + "\n", + "tbc_start_coeff1 = Piecewise((bc_firstd_start_coeff1, yp0 < bc_cutoff),(bc_natural_start_coeff1,True))\n", + "display(tbc_start_coeff1)\n", + "tbc_start_coeff2 = Piecewise((bc_firstd_start_coeff2, yp0 < bc_cutoff),(bc_natural_start_coeff2,True))\n", + "display(tbc_start_coeff2)\n", + "\n", + "sym_bc_start_coeff1 = Symbol('bc_start1')\n", + "sym_bc_start_coeff2 = Symbol('bc_start2')\n", + "bc_eqs = [Eq(sym_bc_start_coeff1, tbc_start_coeff1)]\n", + "bc_eqs.append(Eq(sym_bc_start_coeff2, tbc_start_coeff2))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\frac{\\left({E}_{n - 1} - {E}_{n - 2}\\right) {L}_{n - 2}}{2} + \\frac{- \\frac{\\left({E}_{n - 1} + 2 {E}_{n - 2}\\right) {L}_{n - 2}^{2}}{6} + {y}_{n - 1} - {y}_{n - 2}}{{L}_{n - 2}} + {E}_{n - 2} {L}_{n - 2} = ypn$$" + ], + "text/plain": [ + " 2 \n", + " (E[n - 1] + 2⋅E[n - 2])⋅L[n - 2] \n", + " - ───────────────────────────────── + y[n - 1\n", + "(E[n - 1] - E[n - 2])⋅L[n - 2] 6 \n", + "────────────────────────────── + ─────────────────────────────────────────────\n", + " 2 L[n - 2] \n", + "\n", + " \n", + " \n", + "] - y[n - 2] \n", + " \n", + "──────────── + E[n - 2]⋅L[n - 2] = ypn\n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$\\frac{{E}_{n - 1} {L}_{n - 2}}{3} + \\frac{{E}_{n - 2} {L}_{n - 2}}{6} = ypn - \\frac{{y}_{n - 1}}{{L}_{n - 2}} + \\frac{{y}_{n - 2}}{{L}_{n - 2}}$$" + ], + "text/plain": [ + "E[n - 1]⋅L[n - 2] E[n - 2]⋅L[n - 2] y[n - 1] y[n - 2]\n", + "───────────────── + ───────────────── = ypn - ──────── + ────────\n", + " 3 6 L[n - 2] L[n - 2]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$2 {E}_{n - 1} {L}_{n - 2} + {E}_{n - 2} {L}_{n - 2} = 6 ypn - \\frac{6 {y}_{n - 1}}{{L}_{n - 2}} + \\frac{6 {y}_{n - 2}}{{L}_{n - 2}}$$" + ], + "text/plain": [ + " 6⋅y[n - 1] 6⋅y[n - 2]\n", + "2⋅E[n - 1]⋅L[n - 2] + E[n - 2]⋅L[n - 2] = 6⋅ypn - ────────── + ──────────\n", + " L[n - 2] L[n - 2] " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${L}_{n - 2}$$" + ], + "text/plain": [ + "L[n - 2]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$2 {L}_{n - 2}$$" + ], + "text/plain": [ + "2⋅L[n - 2]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# BC - first derivative specified at the end of the range\n", + "ypn = Symbol('ypn')\n", + "eqbc2=Eq(diff(si,t).subs(t,L[end-1]).subs(sln).subs(i,end-1),ypn)\n", + "display(eqbc2)\n", + "eqbc2b = divide_terms(expand(eqbc2),[E[end-1],E[end]],[y[end-1],y[end],ypn])\n", + "display(eqbc2b)\n", + "eqbc2c = mult_eqn(eqbc2b, 6)\n", + "display(eqbc2c)\n", + "bc_firstd_end_coeff1 = get_coeff_for(eqbc2c.lhs, E[end-1],[E[end-1],E[end]])\n", + "display(bc_firstd_end_coeff1)\n", + "bc_firstd_end_coeff2 = get_coeff_for(eqbc2c.lhs, E[end],[E[end-1],E[end]])\n", + "display(bc_firstd_end_coeff2)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\begin{cases} {L}_{n - 2} & \\text{for}\\: ypn < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}$$" + ], + "text/plain": [ + "⎧L[n - 2] for ypn < 9.9e+29\n", + "⎨ \n", + "⎩ 0 otherwise " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$\\begin{cases} 2 {L}_{n - 2} & \\text{for}\\: ypn < 9.9 \\cdot 10^{29} \\\\1 & \\text{otherwise} \\end{cases}$$" + ], + "text/plain": [ + "⎧2⋅L[n - 2] for ypn < 9.9e+29\n", + "⎨ \n", + "⎩ 1 otherwise " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Create the conditional expression for the end BC\n", + "tbc_end_coeff1 = Piecewise((bc_firstd_end_coeff1, ypn < bc_cutoff),(bc_natural_end_coeff1, True))\n", + "display(tbc_end_coeff1)\n", + "sym_bc_end_coeff1 = Symbol('bc_end1')\n", + "bc_eqs.append(Eq(sym_bc_end_coeff1, tbc_end_coeff1))\n", + "tbc_end_coeff2 = Piecewise((bc_firstd_end_coeff2, ypn < bc_cutoff),(bc_natural_end_coeff2, True))\n", + "tbc_end_coeff2\n", + "display(tbc_end_coeff2)\n", + "sym_bc_end_coeff2 = Symbol('bc_end2')\n", + "bc_eqs.append(Eq(sym_bc_end_coeff2, tbc_end_coeff2))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\begin{cases} 6 yp_{0} + \\frac{6 {y}_{0}}{{L}_{0}} - \\frac{6 {y}_{1}}{{L}_{0}} & \\text{for}\\: yp_{0} < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}$$" + ], + "text/plain": [ + "⎧ 6⋅y[0] 6⋅y[1] \n", + "⎪6⋅yp₀ + ────── - ────── for yp₀ < 9.9e+29\n", + "⎨ L[0] L[0] \n", + "⎪ \n", + "⎩ 0 otherwise " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$\\begin{cases} 6 ypn - \\frac{6 {y}_{n - 1}}{{L}_{n - 2}} + \\frac{6 {y}_{n - 2}}{{L}_{n - 2}} & \\text{for}\\: ypn < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}$$" + ], + "text/plain": [ + "⎧ 6⋅y[n - 1] 6⋅y[n - 2] \n", + "⎪6⋅ypn - ────────── + ────────── for ypn < 9.9e+29\n", + "⎨ L[n - 2] L[n - 2] \n", + "⎪ \n", + "⎩ 0 otherwise " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$\\left [ bc_{start1} = \\begin{cases} - 2 {L}_{0} & \\text{for}\\: yp_{0} < 9.9 \\cdot 10^{29} \\\\1 & \\text{otherwise} \\end{cases}, \\quad bc_{start2} = \\begin{cases} - {L}_{0} & \\text{for}\\: yp_{0} < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}, \\quad bc_{end1} = \\begin{cases} {L}_{n - 2} & \\text{for}\\: ypn < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}, \\quad bc_{end2} = \\begin{cases} 2 {L}_{n - 2} & \\text{for}\\: ypn < 9.9 \\cdot 10^{29} \\\\1 & \\text{otherwise} \\end{cases}, \\quad rhs_{start} = \\begin{cases} 6 yp_{0} + \\frac{6 {y}_{0}}{{L}_{0}} - \\frac{6 {y}_{1}}{{L}_{0}} & \\text{for}\\: yp_{0} < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}, \\quad rhs_{end} = \\begin{cases} 6 ypn - \\frac{6 {y}_{n - 1}}{{L}_{n - 2}} + \\frac{6 {y}_{n - 2}}{{L}_{n - 2}} & \\text{for}\\: ypn < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}\\right ]$$" + ], + "text/plain": [ + "⎡ \n", + "⎢ ⎧-2⋅L[0] for yp₀ < 9.9e+29 ⎧-L[0] for yp₀ < 9.9e+29 \n", + "⎢bcₛₜₐᵣₜ₁ = ⎨ , bcₛₜₐᵣₜ₂ = ⎨ ,\n", + "⎢ ⎩ 1 otherwise ⎩ 0 otherwise \n", + "⎣ \n", + "\n", + " \n", + " ⎧L[n - 2] for ypn < 9.9e+29 ⎧2⋅L[n - 2] for ypn < 9.9e\n", + " bc_end1 = ⎨ , bc_end2 = ⎨ \n", + " ⎩ 0 otherwise ⎩ 1 otherwise \n", + " \n", + "\n", + " ⎧ 6⋅y[0] 6⋅y[1] ⎧ \n", + "+29 ⎪6⋅yp₀ + ────── - ────── for yp₀ < 9.9e+29 ⎪6⋅ypn \n", + " , rhsₛₜₐᵣₜ = ⎨ L[0] L[0] , rhs_end = ⎨ \n", + " ⎪ ⎪ \n", + " ⎩ 0 otherwise ⎩ \n", + "\n", + " 6⋅y[n - 1] 6⋅y[n - 2] ⎤\n", + "- ────────── + ────────── for ypn < 9.9e+29⎥\n", + " L[n - 2] L[n - 2] ⎥\n", + " ⎥\n", + " 0 otherwise ⎦" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# conditional expressions for RHS for boundary conditions\n", + "rhs_start = Piecewise((eqbc1c.rhs,yp0 < bc_cutoff),(bc_natural_start.rhs,True))\n", + "display(rhs_start)\n", + "rhs_end = Piecewise((eqbc2c.rhs, ypn < bc_cutoff), (bc_natural_end.rhs, True))\n", + "display(rhs_end)\n", + "\n", + "sym_rhs_start = Symbol('rhs_start')\n", + "sym_rhs_end = Symbol('rhs_end')\n", + "bc_eqs.append(Eq(sym_rhs_start, rhs_start))\n", + "bc_eqs.append(Eq(sym_rhs_end, rhs_end))\n", + "bc_eqs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ### Substitute cubic spline equations into tridiagonal solver" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\left \\{ {a}_{0} : 0, \\quad {a}_{i} : {L}_{i - 1}, \\quad {a}_{n - 1} : bc_{end1}, \\quad {b}_{0} : bc_{start1}, \\quad {b}_{i} : 2 {L}_{i - 1} + 2 {L}_{i}, \\quad {b}_{n - 1} : bc_{end2}, \\quad {c}_{0} : bc_{start2}, \\quad {c}_{i} : {L}_{i}, \\quad {c}_{n - 1} : 0, \\quad {d}_{0} : rhs_{start}, \\quad {d}_{i} : \\frac{6 {y}_{i + 1}}{{L}_{i}} - \\frac{6 {y}_{i}}{{L}_{i}} + \\frac{6 {y}_{i - 1}}{{L}_{i - 1}} - \\frac{6 {y}_{i}}{{L}_{i - 1}}, \\quad {d}_{n - 1} : rhs_{end}\\right \\}$$" + ], + "text/plain": [ + "⎧ \n", + "⎨a[0]: 0, a[i]: L[i - 1], a[n - 1]: bc_end1, b[0]: bcₛₜₐᵣₜ₁, b[i]: 2⋅L[i - 1] \n", + "⎩ \n", + "\n", + " \n", + "+ 2⋅L[i], b[n - 1]: bc_end2, c[0]: bcₛₜₐᵣₜ₂, c[i]: L[i], c[n - 1]: 0, d[0]: rh\n", + " \n", + "\n", + " 6⋅y[i + 1] 6⋅y[i] 6⋅y[i - 1] 6⋅y[i] ⎫\n", + "sₛₜₐᵣₜ, d[i]: ────────── - ────── + ────────── - ────────, d[n - 1]: rhs_end⎬\n", + " L[i] L[i] L[i - 1] L[i - 1] ⎭" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "subslist = {\n", + " a[start] : 0,\n", + " a[i] : coeff1,\n", + " a[end] : sym_bc_end_coeff1,\n", + " \n", + " b[start] : sym_bc_start_coeff1,\n", + " b[i] : coeff2,\n", + " b[end] : sym_bc_end_coeff2,\n", + " \n", + " c[start] : sym_bc_start_coeff2,\n", + " c[i] : coeff3,\n", + " c[end] : 0,\n", + " \n", + " d[start] : sym_rhs_start,\n", + " d[i] : sp9.rhs,\n", + " d[end] : sym_rhs_end,\n", + "}\n", + "\n", + "# Replace knot spacing with differences bewteen knot locations\n", + "subsL = {\n", + " L[i] : x[i+1] - x[i],\n", + " L[i+1] : x[i+2] - x[i+1],\n", + " L[i-1] : x[i] - x[i-1],\n", + " L[start] : x[start+1]-x[start],\n", + " L[start+1] : x[start+2]-x[start+1],\n", + " L[end-1] : x[end] - x[end-1],\n", + "}\n", + "subslist" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${c'}_{0} = \\frac{bc_{start2}}{bc_{start1}}$$" + ], + "text/plain": [ + " bcₛₜₐᵣₜ₂\n", + "c'[0] = ────────\n", + " bcₛₜₐᵣₜ₁" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${d'}_{0} = \\frac{rhs_{start}}{bc_{start1}}$$" + ], + "text/plain": [ + " rhsₛₜₐᵣₜ\n", + "d'[0] = ────────\n", + " bcₛₜₐᵣₜ₁" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${d'}_{i} = \\frac{- \\left({x}_{i + 1} - {x}_{i}\\right) \\left({x}_{i - 1} - {x}_{i}\\right)^{2} {d'}_{i - 1} + 6 \\left({x}_{i + 1} - {x}_{i}\\right) \\left({y}_{i - 1} - {y}_{i}\\right) + 6 \\left({x}_{i - 1} - {x}_{i}\\right) \\left(- {y}_{i + 1} + {y}_{i}\\right)}{\\left({x}_{i + 1} - {x}_{i}\\right) \\left({x}_{i - 1} - {x}_{i}\\right) \\left(- \\left({x}_{i - 1} - {x}_{i}\\right) {c'}_{i - 1} - 2 {x}_{i + 1} + 2 {x}_{i - 1}\\right)}$$" + ], + "text/plain": [ + " 2 \n", + " - (x[i + 1] - x[i])⋅(x[i - 1] - x[i]) ⋅d'[i - 1] + 6⋅(x[i + 1] - x[i])\n", + "d'[i] = ──────────────────────────────────────────────────────────────────────\n", + " (x[i + 1] - x[i])⋅(x[i - 1] - x[i])⋅(-(x[i - 1] - x\n", + "\n", + " \n", + "⋅(y[i - 1] - y[i]) + 6⋅(x[i - 1] - x[i])⋅(-y[i + 1] + y[i])\n", + "───────────────────────────────────────────────────────────\n", + "[i])⋅c'[i - 1] - 2⋅x[i + 1] + 2⋅x[i - 1]) " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${c'}_{i} = \\frac{{x}_{i + 1} - {x}_{i}}{- \\left(- {x}_{i - 1} + {x}_{i}\\right) {c'}_{i - 1} + 2 {x}_{i + 1} - 2 {x}_{i - 1}}$$" + ], + "text/plain": [ + " x[i + 1] - x[i] \n", + "c'[i] = ───────────────────────────────────────────────────────\n", + " -(-x[i - 1] + x[i])⋅c'[i - 1] + 2⋅x[i + 1] - 2⋅x[i - 1]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${x}_{n - 1} = \\frac{- bc_{end1} {d'}_{n - 2} + rhs_{end}}{- bc_{end1} {c'}_{n - 2} + bc_{end2}}$$" + ], + "text/plain": [ + " -bc_end1⋅d'[n - 2] + rhs_end\n", + "x[n - 1] = ────────────────────────────\n", + " -bc_end1⋅c'[n - 2] + bc_end2" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${x}_{i} = - {c'}_{i} {x}_{i + 1} + {d'}_{i}$$" + ], + "text/plain": [ + "x[i] = -c'[i]⋅x[i + 1] + d'[i]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Substitute into the tridiagonal solver\n", + "display(teq1.subs(subslist))\n", + "teq2b = teq2.subs(subslist).subs(subsL)\n", + "display(teq2b)\n", + "teq3b = simplify(teq3.subs(subslist).subs(subsL))\n", + "display(teq3b)\n", + "teq4b = teq4.subs(subslist).subs(subsL)\n", + "display(teq4b)\n", + "teq5b = Eq(teq5.lhs,teq5.rhs.subs(dp[end],teq3.rhs).subs(i,end).subs(subslist))\n", + "display(teq5b)\n", + "display(teq6.subs(subslist))" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\left [ \\left ( z_{0}, \\quad - {x}_{i}\\right ), \\quad \\left ( z_{1}, \\quad z_{0} + {x}_{i + 1}\\right ), \\quad \\left ( z_{2}, \\quad z_{0} + {x}_{i - 1}\\right ), \\quad \\left ( z_{3}, \\quad 2 {x}_{i + 1}\\right ), \\quad \\left ( z_{4}, \\quad 2 {x}_{i - 1}\\right ), \\quad \\left ( z_{5}, \\quad z_{2} {c'}_{i - 1}\\right ), \\quad \\left ( z_{6}, \\quad - {y}_{i}\\right )\\right ]$$" + ], + "text/plain": [ + "[(z₀, -x[i]), (z₁, z₀ + x[i + 1]), (z₂, z₀ + x[i - 1]), (z₃, 2⋅x[i + 1]), (z₄,\n", + " 2⋅x[i - 1]), (z₅, z₂⋅c'[i - 1]), (z₆, -y[i])]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$$\\left [ {d'}_{i} = \\frac{z_{1} z_{2}^{2} {d'}_{i - 1} - 6 z_{1} \\left(z_{6} + {y}_{i - 1}\\right) + 6 z_{2} \\left(z_{6} + {y}_{i + 1}\\right)}{z_{1} z_{2} \\left(z_{3} - z_{4} + z_{5}\\right)}, \\quad {c'}_{i} = \\frac{- {x}_{i + 1} + {x}_{i}}{- z_{3} + z_{4} - z_{5}}\\right ]$$" + ], + "text/plain": [ + "⎡ 2 \n", + "⎢ z₁⋅z₂ ⋅d'[i - 1] - 6⋅z₁⋅(z₆ + y[i - 1]) + 6⋅z₂⋅(z₆ + y[i + 1]) \n", + "⎢d'[i] = ──────────────────────────────────────────────────────────────, c'[i]\n", + "⎣ z₁⋅z₂⋅(z₃ - z₄ + z₅) \n", + "\n", + " ⎤\n", + " -x[i + 1] + x[i]⎥\n", + " = ────────────────⎥\n", + " -z₃ + z₄ - z₅ ⎦" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Extract sub-expressions\n", + "subexpr, final_expr = cse([simplify(teq3b),simplify(teq4b)],symbols=numbered_symbols('z'))\n", + "display(subexpr)\n", + "display(final_expr)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$$\\left [ bc_{start1} = \\begin{cases} 2 {x}_{0} - 2 {x}_{1} & \\text{for}\\: yp_{0} < 9.9 \\cdot 10^{29} \\\\1 & \\text{otherwise} \\end{cases}, \\quad bc_{start2} = \\begin{cases} {x}_{0} - {x}_{1} & \\text{for}\\: yp_{0} < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}, \\quad bc_{end1} = \\begin{cases} {x}_{n - 1} - {x}_{n - 2} & \\text{for}\\: ypn < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}, \\quad bc_{end2} = \\begin{cases} 2 {x}_{n - 1} - 2 {x}_{n - 2} & \\text{for}\\: ypn < 9.9 \\cdot 10^{29} \\\\1 & \\text{otherwise} \\end{cases}, \\quad rhs_{start} = \\begin{cases} 6 yp_{0} + \\frac{6 {y}_{0}}{- {x}_{0} + {x}_{1}} - \\frac{6 {y}_{1}}{- {x}_{0} + {x}_{1}} & \\text{for}\\: yp_{0} < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}, \\quad rhs_{end} = \\begin{cases} 6 ypn - \\frac{6 {y}_{n - 1}}{{x}_{n - 1} - {x}_{n - 2}} + \\frac{6 {y}_{n - 2}}{{x}_{n - 1} - {x}_{n - 2}} & \\text{for}\\: ypn < 9.9 \\cdot 10^{29} \\\\0 & \\text{otherwise} \\end{cases}\\right ]$$" + ], + "text/plain": [ + "⎡ \n", + "⎢ ⎧2⋅x[0] - 2⋅x[1] for yp₀ < 9.9e+29 ⎧x[0] - x[1] for \n", + "⎢bcₛₜₐᵣₜ₁ = ⎨ , bcₛₜₐᵣₜ₂ = ⎨ \n", + "⎢ ⎩ 1 otherwise ⎩ 0 \n", + "⎣ \n", + "\n", + " \n", + "yp₀ < 9.9e+29 ⎧x[n - 1] - x[n - 2] for ypn < 9.9e+29 ⎧2\n", + " , bc_end1 = ⎨ , bc_end2 = ⎨ \n", + "otherwise ⎩ 0 otherwise ⎩ \n", + " \n", + "\n", + " ⎧ 6⋅y[0] \n", + "⋅x[n - 1] - 2⋅x[n - 2] for ypn < 9.9e+29 ⎪6⋅yp₀ + ──────────── - \n", + " , rhsₛₜₐᵣₜ = ⎨ -x[0] + x[1] \n", + " 1 otherwise ⎪ \n", + " ⎩ 0 \n", + "\n", + " 6⋅y[1] ⎧ 6⋅y[n - 1] \n", + "──────────── for yp₀ < 9.9e+29 ⎪6⋅ypn - ─────────────────── + ────\n", + "-x[0] + x[1] , rhs_end = ⎨ x[n - 1] - x[n - 2] x[n \n", + " ⎪ \n", + " otherwise ⎩ 0 \n", + "\n", + " 6⋅y[n - 2] ⎤\n", + "─────────────── for ypn < 9.9e+29⎥\n", + "- 1] - x[n - 2] ⎥\n", + " ⎥\n", + " otherwise ⎦" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Substitute knot spacing into the boundary conditions\n", + "bc_eqs2 = [eq.subs(subsL) for eq in bc_eqs]\n", + "bc_eqs2" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$${y_{2}}_{0} = \\frac{bc_{start2}}{bc_{start1}}$$" + ], + "text/plain": [ + " bcₛₜₐᵣₜ₂\n", + "y2[0] = ────────\n", + " bcₛₜₐᵣₜ₁" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${u}_{0} = \\frac{rhs_{start}}{bc_{start1}}$$" + ], + "text/plain": [ + " rhsₛₜₐᵣₜ\n", + "u[0] = ────────\n", + " bcₛₜₐᵣₜ₁" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${u}_{i} = \\frac{z_{1} z_{2}^{2} {u}_{i - 1} - 6 z_{1} \\left(z_{6} + {y}_{i - 1}\\right) + 6 z_{2} \\left(z_{6} + {y}_{i + 1}\\right)}{z_{1} z_{2} \\left(z_{3} - z_{4} + z_{5}\\right)}$$" + ], + "text/plain": [ + " 2 \n", + " z₁⋅z₂ ⋅u[i - 1] - 6⋅z₁⋅(z₆ + y[i - 1]) + 6⋅z₂⋅(z₆ + y[i + 1])\n", + "u[i] = ─────────────────────────────────────────────────────────────\n", + " z₁⋅z₂⋅(z₃ - z₄ + z₅) " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${y_{2}}_{i} = \\frac{- {x}_{i + 1} + {x}_{i}}{- z_{3} + z_{4} - z_{5}}$$" + ], + "text/plain": [ + " -x[i + 1] + x[i]\n", + "y2[i] = ────────────────\n", + " -z₃ + z₄ - z₅ " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${y_{2}}_{n - 1} = \\frac{- bc_{end1} {u}_{n - 2} + rhs_{end}}{- bc_{end1} {y_{2}}_{n - 2} + bc_{end2}}$$" + ], + "text/plain": [ + " -bc_end1⋅u[n - 2] + rhs_end \n", + "y2[n - 1] = ────────────────────────────\n", + " -bc_end1⋅y2[n - 2] + bc_end2" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$${y_{2}}_{i} = {u}_{i} - {y_{2}}_{i + 1} {y_{2}}_{i}$$" + ], + "text/plain": [ + "y2[i] = u[i] - y2[i + 1]⋅y2[i]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Use temporary storage for cp, and reuse output vector for dp\n", + "# In the future there should be some dependency analysis to verify this is a legal transformation\n", + "tmp = IndexedBase('u',shape=(n,))\n", + "y2 = IndexedBase('y2',shape=(n,))\n", + "storage_subs = {cp:y2, dp:tmp}\n", + "#storage_subs = {}\n", + "teq1c = teq1.subs(subslist).subs(storage_subs)\n", + "display(teq1c)\n", + "teq2c = teq2b.subs(subslist).subs(storage_subs)\n", + "display(teq2c)\n", + "teq3c = final_expr[0].subs(storage_subs)\n", + "display(teq3c)\n", + "teq4c = final_expr[1].subs(storage_subs)\n", + "display(teq4c)\n", + "teq5c = teq5b.subs(storage_subs).subs(x,y2)\n", + "display(teq5c)\n", + "teq6c = teq6.subs(storage_subs).subs(x,y2)\n", + "display(teq6c)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "# Now for some code generation\n", + "#reload(codegen_more)\n", + "#from codegen_more import *" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "templateT = Type('T')" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "# forward sweep\n", + "fr = ARange(start+1,end,1)\n", + "\n", + "body = []\n", + "for e in subexpr:\n", + " body.append(Variable(e[0],type=templateT).as_Declaration(value=e[1].subs(storage_subs)))\n", + " \n", + "body.append(convert_eq_to_assignment(teq3c))\n", + "body.append(convert_eq_to_assignment(teq4c))\n", + "loop1 = For(i,fr,body)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "# backward sweep\n", + "br = ARangeClosedEnd(end-1,start,-1)\n", + "loop2 = For(i,br,[convert_eq_to_assignment(teq6c)])" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "tmp_init = VariableWithInit(\"n\",tmp,type=Type(\"std::vector\")).as_Declaration()\n", + "bc_tmps = []\n", + "for e in bc_eqs2:\n", + " bc_tmps.append(Variable(e.lhs, type=templateT).as_Declaration(value=e.rhs))\n", + "algo = CodeBlock(tmp_init,\n", + " *bc_tmps,\n", + " convert_eq_to_assignment(teq1c),\n", + " convert_eq_to_assignment(teq2c),\n", + " loop1,\n", + " convert_eq_to_assignment(teq5c),\n", + " loop2)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "// Not supported in C++:\n", + "// IndexedBase\n", + "std::vector u(n);\n", + "T bc_start1 = ((yp0 < 9.9000000000000002e+29) ? (\n", + " 2*x[0] - 2*x[1]\n", + ")\n", + ": (\n", + " 1\n", + "));\n", + "T bc_start2 = ((yp0 < 9.9000000000000002e+29) ? (\n", + " x[0] - x[1]\n", + ")\n", + ": (\n", + " 0\n", + "));\n", + "T bc_end1 = ((ypn < 9.9000000000000002e+29) ? (\n", + " x[n - 1] - x[n - 2]\n", + ")\n", + ": (\n", + " 0\n", + "));\n", + "T bc_end2 = ((ypn < 9.9000000000000002e+29) ? (\n", + " 2*x[n - 1] - 2*x[n - 2]\n", + ")\n", + ": (\n", + " 1\n", + "));\n", + "T rhs_start = ((yp0 < 9.9000000000000002e+29) ? (\n", + " 6*yp0 + 6*y[0]/(-x[0] + x[1]) - 6*y[1]/(-x[0] + x[1])\n", + ")\n", + ": (\n", + " 0\n", + "));\n", + "T rhs_end = ((ypn < 9.9000000000000002e+29) ? (\n", + " 6*ypn - 6*y[n - 1]/(x[n - 1] - x[n - 2]) + 6*y[n - 2]/(x[n - 1] - x[n - 2])\n", + ")\n", + ": (\n", + " 0\n", + "));\n", + "y2[0] = bc_start2/bc_start1;\n", + "u[0] = rhs_start/bc_start1;\n", + "for (auto i = 1; i < n - 1; i += 1) {\n", + " T z0 = -x[i];\n", + " T z1 = z0 + x[i + 1];\n", + " T z2 = z0 + x[i - 1];\n", + " T z3 = 2*x[i + 1];\n", + " T z4 = 2*x[i - 1];\n", + " T z5 = z2*y2[i - 1];\n", + " T z6 = -y[i];\n", + " u[i] = (z1*z2*z2*u[i - 1] - 6*z1*(z6 + y[i - 1]) + 6*z2*(z6 + y[i + 1]))/(z1*z2*(z3 - z4 + z5));\n", + " y2[i] = (-x[i + 1] + x[i])/(-z3 + z4 - z5);\n", + "};\n", + "y2[n - 1] = (-bc_end1*u[n - 2] + rhs_end)/(-bc_end1*y2[n - 2] + bc_end2);\n", + "for (auto i = n - 2; i >= 0; i += -1) {\n", + " y2[i] = u[i] - y2[i + 1]*y2[i];\n", + "};\n" + ] + } + ], + "source": [ + "# Generate the inner part of the algorithm to check it\n", + "ACP = ACodePrinter()\n", + "s = ACP.doprint(algo)\n", + "print(s)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up to create a template function\n", + "tx = Pointer(x,type=templateT)\n", + "ty = Pointer(y,type=templateT)\n", + "ty2 = Pointer(y2,type=templateT)\n", + "yp0_var = Variable('yp0',type=templateT)\n", + "ypn_var = Variable('ypn',type=templateT)\n", + "\n", + "tf = TemplateFunctionDefinition(void, \"cubic_spline_solve\",[tx,ty,n,yp0_var,ypn_var,ty2],[templateT],algo)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "// Not supported in C++:\n", + "// IndexedBase\n", + "// IndexedBase\n", + "// IndexedBase\n", + "// IndexedBase\n", + "template\n", + "void cubic_spline_solve(T * x, T * y, int n, T yp0, T ypn, T * y2){\n", + " std::vector u(n);\n", + " T bc_start1 = ((yp0 < 9.9000000000000002e+29) ? (\n", + " 2*x[0] - 2*x[1]\n", + " )\n", + " : (\n", + " 1\n", + " ));\n", + " T bc_start2 = ((yp0 < 9.9000000000000002e+29) ? (\n", + " x[0] - x[1]\n", + " )\n", + " : (\n", + " 0\n", + " ));\n", + " T bc_end1 = ((ypn < 9.9000000000000002e+29) ? (\n", + " x[n - 1] - x[n - 2]\n", + " )\n", + " : (\n", + " 0\n", + " ));\n", + " T bc_end2 = ((ypn < 9.9000000000000002e+29) ? (\n", + " 2*x[n - 1] - 2*x[n - 2]\n", + " )\n", + " : (\n", + " 1\n", + " ));\n", + " T rhs_start = ((yp0 < 9.9000000000000002e+29) ? (\n", + " 6*yp0 + 6*y[0]/(-x[0] + x[1]) - 6*y[1]/(-x[0] + x[1])\n", + " )\n", + " : (\n", + " 0\n", + " ));\n", + " T rhs_end = ((ypn < 9.9000000000000002e+29) ? (\n", + " 6*ypn - 6*y[n - 1]/(x[n - 1] - x[n - 2]) + 6*y[n - 2]/(x[n - 1] - x[n - 2])\n", + " )\n", + " : (\n", + " 0\n", + " ));\n", + " y2[0] = bc_start2/bc_start1;\n", + " u[0] = rhs_start/bc_start1;\n", + " for (auto i = 1; i < n - 1; i += 1) {\n", + " T z0 = -x[i];\n", + " T z1 = z0 + x[i + 1];\n", + " T z2 = z0 + x[i - 1];\n", + " T z3 = 2*x[i + 1];\n", + " T z4 = 2*x[i - 1];\n", + " T z5 = z2*y2[i - 1];\n", + " T z6 = -y[i];\n", + " u[i] = (z1*z2*z2*u[i - 1] - 6*z1*(z6 + y[i - 1]) + 6*z2*(z6 + y[i + 1]))/(z1*z2*(z3 - z4 + z5));\n", + " y2[i] = (-x[i + 1] + x[i])/(-z3 + z4 - z5);\n", + " };\n", + " y2[n - 1] = (-bc_end1*u[n - 2] + rhs_end)/(-bc_end1*y2[n - 2] + bc_end2);\n", + " for (auto i = n - 2; i >= 0; i += -1) {\n", + " y2[i] = u[i] - y2[i + 1]*y2[i];\n", + " };\n", + "}\n" + ] + } + ], + "source": [ + "ACP = ACodePrinter()\n", + "s = ACP.doprint(tf)\n", + "print(s)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Wavefunctions/codegen_extras.py b/Wavefunctions/codegen_extras.py new file mode 100644 index 0000000..4420d01 --- /dev/null +++ b/Wavefunctions/codegen_extras.py @@ -0,0 +1,155 @@ +from __future__ import print_function + +# Extensions to Sympy code generation to support C++ with templates and more + +from sympy import Set, Basic, Tuple, IndexedBase +from sympy.codegen.ast import Assignment, Pointer, Node, Type +from sympy.codegen.ast import String, Declaration, Variable +from sympy.printing.cxxcode import CXX11CodePrinter + +# A range class that accepts symbolic limits. Purpose is For loops +class ARange(Set): + is_iterable = True + + def __new__(cls, *args): + slc = slice(*args) + start = slc.start + stop = slc.stop + step = slc.step + + return Basic.__new__(cls, start, stop, step) + start = property(lambda self: self.args[0]) + stop = property(lambda self: self.args[1]) + step = property(lambda self: self.args[2]) + + # Just here to make it pass the 'iterable' test + def __iter__(self): + i = 0 + yield i + +# The end value should be included in the iteration +class ARangeClosedEnd(ARange): + pass + + +# Convert Eq to Assignment +def convert_eq_to_assignment(expr): + return Assignment(expr.lhs, expr.rhs) + + +# Node for a C++ reference +class Reference(Pointer): + """ Represents a C++ reference""" + pass + + +# Specify direct initialization (with parentheses) +# e.g. int j(0); +class VariableWithInit(Variable): + + __slots__ = ['type_init'] + Variable.__slots__ + + +# Templated function definition +class TemplateFunctionDefinition(Node): + __slots__ = ['return_type','name','parameters','template_types','body','attrs'] + _construct_return_type = Type + _construct_name = String + + @staticmethod + def _construct_parameters(args): + def _var(arg): + if isinstance(arg, Declaration): + return arg.variable + elif isinstance(arg, Variable): + return arg + else: + return Variable.deduced(arg) + return Tuple(*map(_var, args)) + + @staticmethod + def _construct_template_types(args): + return Tuple(*args) + + +# Code printer for extended features +class ACodePrinter(CXX11CodePrinter): + def __init__(self, settings=None): + super(ACodePrinter, self).__init__(settings=settings) + + def _print_Assignment(self, expr): + lhs = expr.lhs + rhs = expr.rhs + if lhs.has(IndexedBase) or rhs.has(IndexedBase): + return self._get_statement("%s = %s"%(self._print(lhs),self._print(rhs))) + else: + return super(ACodePrinter, self)._print_Assignment(expr) + + def _print_Pow(self, expr): + if expr.exp == 2: + e = self._print(expr.base) + return '%s*%s'%(e,e) + return super(ACodePrinter, self)._print_Pow(expr) + + def _print_Symbol(self, expr): + name = super(ACodePrinter, self)._print_Symbol(expr) + # Replace prime marker in symbol name with something acceptable in C++ + # Maybe should generalize to a lookup from symbol name to code name? + if "'" in name: + name = name.replace("'", "p") + return name + + def _print_Declaration(self, decl): + #print("decl = ",decl,type(decl)) + var = decl.variable + val = var.value + if isinstance(var, Reference): + result = '{t}& {s}'.format( + t = self._print(var.type), + s = self._print(var.symbol) + ) + return result + elif isinstance(var, VariableWithInit): + result = '{t} {s}({init})'.format( + t=self._print(var.type), + s=self._print(var.symbol), + init=self._print(var.type_init)) + return result + else: + return super(ACodePrinter, self)._print_Declaration(decl) + + def _print_TemplateFunctionDefinition(self, expr): + decl = "template<{template_args}>\n{ret_type} {name}({params}){body}".format( + template_args=', '.join(map(lambda arg: 'typename '+self._print(arg), expr.template_types)), + ret_type=self._print(expr.return_type), + name=expr.name, + params=', '.join(map(lambda arg: self._print(Declaration(arg)), expr.parameters)), + body=self._print_Scope(expr) + + ) + return decl + + def _print_For(self, expr): + target = self._print(expr.target) + it = expr.iterable + body = self._print(expr.body) + #print("it = ",it,type(it),isinstance(it,ARange)) + + if isinstance(it, ARange): + end_compare = "" + if isinstance(it, ARangeClosedEnd): + end_compare="=" + if it.step > 0: + return ("for (auto {target} = {start}; {target} <{end_compare} {stop}; {target} += {step}) {{\n{body}\n}}").format( + target=target,start=it.start, stop=it.stop, step=it.step, body=body,end_compare=end_compare) + else: + return ("for (auto {target} = {start}; {target} >{end_compare} {stop}; {target} += {step}) {{\n{body}\n}}").format( + target=target,start=it.start, stop=it.stop, step=it.step, body=body, end_compare=end_compare) + else: + return super(ACodePrinter, self)._print_For(expr) + + + + + +