-
Notifications
You must be signed in to change notification settings - Fork 1
/
pendulum.py
110 lines (88 loc) · 3.67 KB
/
pendulum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Adapted from https://colab.research.google.com/drive/1CSy-xfrnTX28p1difoTA8ulYw0zytJkq#scrollTo=srZU0YiAQ8rm
from functools import partial
from matplotlib.patches import Circle
import matplotlib.pyplot as plt
import numpy as onp
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
def lagrangian(q, q_dot, m1, m2, l1, l2, g):
    """Return the Lagrangian ``T - V`` of a planar double pendulum.

    Angles are measured from the downward vertical: a mass sits at height
    ``-l * cos(theta)`` relative to its pivot.

    Args:
        q: pair ``(theta1, theta2)`` of rod angles.
        q_dot: pair ``(omega1, omega2)`` of angular velocities.
        m1, m2: masses of the inner and outer bob.
        l1, l2: lengths of the inner and outer rod.
        g: gravitational acceleration.

    Returns:
        Kinetic energy minus potential energy.
    """
    theta1, theta2 = q
    omega1, omega2 = q_dot

    # Kinetic energy: inner bob, then outer bob (its velocity is the vector
    # sum of both rotations, hence the cos(theta1 - theta2) cross term).
    kinetic_inner = 0.5 * m1 * (l1 * omega1)**2
    cross_term = 2 * l1 * l2 * omega1 * omega2 * jnp.cos(theta1 - theta2)
    kinetic_outer = 0.5 * m2 * ((l1 * omega1)**2 + (l2 * omega2)**2 + cross_term)
    kinetic = kinetic_inner + kinetic_outer

    # Gravitational potential energy, heights measured from the pivot.
    height_inner = -l1 * jnp.cos(theta1)
    height_outer = height_inner - l2 * jnp.cos(theta2)
    potential = m1 * g * height_inner + m2 * g * height_outer

    return kinetic - potential
def f_analytical(state, t=0, m1=1, m2=1, l1=1, l2=1, g=9.8):
    """Time derivative of a double-pendulum state from closed-form equations.

    Args:
        state: sequence ``(theta1, theta2, omega1, omega2)``.
        t: time; not used by the equations, but accepted so the function
            matches the ``func(y, t)`` convention expected by ``odeint``.
        m1, m2: masses of the inner and outer bob.
        l1, l2: lengths of the inner and outer rod.
        g: gravitational acceleration.

    Returns:
        ``jnp.stack([omega1, omega2, acc1, acc2])``, i.e. the angular
        velocities followed by the angular accelerations.
    """
    theta1, theta2, omega1, omega2 = state
    delta = theta1 - theta2
    cos_delta = jnp.cos(delta)
    sin_delta = jnp.sin(delta)
    mass_ratio = m2 / (m1 + m2)

    # Coupling coefficients: each equation of motion contains the other
    # bob's angular acceleration scaled by these factors.
    c1 = (l2 / l1) * mass_ratio * cos_delta
    c2 = (l1 / l2) * cos_delta

    # Right-hand sides that do not depend on the accelerations.
    rhs1 = -(l2 / l1) * mass_ratio * omega2**2 * sin_delta - (g / l1) * jnp.sin(theta1)
    rhs2 = (l1 / l2) * omega1**2 * sin_delta - (g / l2) * jnp.sin(theta2)

    # Solve [[1, c1], [c2, 1]] @ [acc1, acc2] = [rhs1, rhs2] by Cramer's rule.
    det = 1 - c1 * c2
    acc1 = (rhs1 - c1 * rhs2) / det
    acc2 = (rhs2 - c2 * rhs1) / det
    return jnp.stack([omega1, omega2, acc1, acc2])
def _equation_of_motion(lagrangian_fn, state, t=None):
    """Return ``d(state)/dt`` for a system described by ``lagrangian_fn``.

    Solves the Euler-Lagrange equations for the generalized accelerations:

        q_tt = pinv(d2L/dq_dot2) @ (dL/dq - (d2L/dq_dot dq) @ q_dot)

    Args:
        lagrangian_fn: callable ``L(q, q_dot)`` returning a scalar.
        state: 1-D array ``[q..., q_dot...]``; split into two equal halves.
        t: unused; present so the function matches odeint's ``func(y, t)``.

    Returns:
        ``jnp.concatenate([q_dot, q_tt])``.
    """
    q, q_t = jnp.split(state, 2)
    hessian_qdot = jax.hessian(lagrangian_fn, 1)(q, q_t)
    grad_q = jax.grad(lagrangian_fn, 0)(q, q_t)
    # Mixed second derivatives d2L / (dq_dot_i dq_j).
    mixed = jax.jacobian(jax.jacobian(lagrangian_fn, 1), 0)(q, q_t)
    # pinv rather than inv so a singular mass matrix does not produce NaNs.
    q_tt = jnp.linalg.pinv(hessian_qdot) @ (grad_q - mixed @ q_t)
    return jnp.concatenate([q_t, q_tt])


# Double pendulum dynamics via the rewritten Euler-Lagrange equations, with
# all derivatives of the Lagrangian obtained by JAX autodiff.
@partial(jax.jit, backend='cpu')
def solve_autograd(initial_state, times, m1=1, m2=1, l1=1, l2=1, g=9.8):
    """Integrate the double pendulum from its Lagrangian using autodiff.

    Args:
        initial_state: array ``[theta1, theta2, omega1, omega2]``.
        times: array of times at which odeint reports the solution.
        m1, m2, l1, l2, g: physical parameters forwarded to ``lagrangian``.

    Returns:
        Array of states, one row per entry of ``times``.
    """
    L = partial(lagrangian, m1=m1, m2=m2, l1=l1, l2=l2, g=g)
    return odeint(partial(_equation_of_motion, L), initial_state, t=times)
# Double pendulum dynamics via analytical forces taken from Diego's blog
@partial(jax.jit, backend='cpu')
def solve_analytical(initial_state, times, m1=1, m2=1, l1=1, l2=1, g=9.8):
    """Integrate the double pendulum using the closed-form ``f_analytical``.

    Args:
        initial_state: array ``[theta1, theta2, omega1, omega2]``.
        times: array of times at which odeint reports the solution.
        m1, m2, l1, l2, g: physical parameters forwarded to ``f_analytical``.

    Returns:
        Array of states, one row per entry of ``times``.
    """
    def dynamics(state, t):
        # Bind the physical parameters so odeint sees a plain func(y, t).
        return f_analytical(state, t, m1=m1, m2=m2, l1=l1, l2=l2, g=g)

    return odeint(dynamics, initial_state, t=times)
def normalize_dp(state):
    """Wrap the two angles of a double-pendulum state into ``[-pi, pi)``.

    Args:
        state: array whose last axis is ``[theta1, theta2, omega1, omega2]``;
            a single state of shape ``(4,)`` or a batch such as ``(T, 4)``.

    Returns:
        Array of the same shape with both angles wrapped and the angular
        velocities unchanged.
    """
    # Shift by pi, floor-mod into [0, 2*pi), shift back -> [-pi, pi).
    angles = (state[..., :2] + jnp.pi) % (2 * jnp.pi) - jnp.pi
    return jnp.concatenate([angles, state[..., 2:]], axis=-1)
def rk4_step(f, x, t, h):
    """Advance ``x`` by one classical fourth-order Runge-Kutta step.

    Args:
        f: derivative function called as ``f(x, t)``.
        x: current state.
        t: current time.
        h: step size.

    Returns:
        The state at time ``t + h``.
    """
    half_step = h / 2
    t_mid = t + half_step

    # Four slope evaluations: start, two midpoints, end.
    k1 = h * f(x, t)
    k2 = h * f(x + k1 / 2, t_mid)
    k3 = h * f(x + k2 / 2, t_mid)
    k4 = h * f(x + k3, t + h)

    weighted = k1 + 2 * k2 + 2 * k3 + k4
    return x + 1 / 6 * weighted
def make_plot(i, cart_coords, l1, l2, ax, max_trail=30, trail_segments=20, r=0.05):
    """Draw the double-pendulum configuration at time step ``i`` onto ``ax``.

    Nothing is saved to disk; the caller renders or saves the figure.

    Args:
        i: index of the time step to draw.
        cart_coords: tuple ``(x1, y1, x2, y2)`` of indexable position
            sequences (e.g. the output of ``radial2cartesian``).
        l1, l2: rod lengths; used only to set symmetric axis limits.
        ax: matplotlib Axes to draw on. It is cleared first.
        max_trail: approximate number of past steps covered by the trail
            of mass 2.
        trail_segments: number of pieces the trail is drawn in; older
            pieces are more transparent.
        r: radius of the mass markers (the anchor uses ``r / 2``).
    """
    # Clear the Axes we were given. plt.cla() would clear pyplot's "current"
    # Axes instead, which need not be ``ax`` (and creates a figure if none exists).
    ax.cla()
    x1, y1, x2, y2 = cart_coords
    ax.plot([0, x1[i], x2[i]], [0, y1[i], y2[i]], lw=2, c='k')  # rods
    ax.add_patch(Circle((0, 0), r / 2, fc='k', zorder=10))  # anchor point
    ax.add_patch(Circle((x1[i], y1[i]), r, fc='b', ec='b', zorder=10))  # mass 1
    ax.add_patch(Circle((x2[i], y2[i]), r, fc='r', ec='r', zorder=10))  # mass 2

    # Trail of mass 2: trail_segments pieces of s steps each. Pieces that
    # would start before step 0 are skipped.
    s = max_trail // trail_segments
    for j in range(trail_segments):
        imin = i - (trail_segments - j) * s
        if imin < 0:
            continue
        imax = imin + s + 1  # +1 so consecutive pieces share an endpoint
        alpha = (j / trail_segments) ** 2  # fade older pieces out
        ax.plot(x2[imin:imax], y2[imin:imax], c='r', solid_capstyle='butt',
                lw=2, alpha=alpha)

    # Center on the fixed anchor point with equal axis scaling.
    extent = l1 + l2 + r
    ax.set_xlim(-extent, extent)
    ax.set_ylim(-extent, extent)
    ax.set_aspect('equal', adjustable='box')
    ax.axis('off')
def radial2cartesian(t1, t2, l1, l2):
    """Convert double-pendulum angles to Cartesian bob positions.

    The pivot is at the origin and angles are measured from the downward
    vertical.

    Args:
        t1, t2: angles of the inner and outer rod (scalars or arrays).
        l1, l2: rod lengths.

    Returns:
        Tuple ``(x1, y1, x2, y2)`` with the inner bob at ``(x1, y1)`` and
        the outer bob at ``(x2, y2)``.
    """
    inner_x = l1 * jnp.sin(t1)
    inner_y = -l1 * jnp.cos(t1)
    # The outer bob hangs from the inner one.
    outer_x = inner_x + l2 * jnp.sin(t2)
    outer_y = inner_y - l2 * jnp.cos(t2)
    return inner_x, inner_y, outer_x, outer_y
def fig2image(fig):
    """Render ``fig`` and return its pixels as an ``(height, width, 3)`` uint8 array.

    Requires a canvas that provides ``buffer_rgba`` (e.g. the Agg canvas).
    The previous ``tostring_rgb`` API was removed in Matplotlib 3.10, and
    binary-mode ``numpy.fromstring`` is deprecated.

    Args:
        fig: matplotlib Figure to render.

    Returns:
        RGB image array, independent of the canvas's internal buffer.
    """
    fig.canvas.draw()
    # buffer_rgba() already carries the true pixel shape (H, W, 4), including
    # any HiDPI scaling, so no reshape via get_width_height() is needed.
    rgba = onp.asarray(fig.canvas.buffer_rgba())
    # Drop alpha; copy so the result does not alias the canvas buffer.
    return rgba[..., :3].copy()