-
Notifications
You must be signed in to change notification settings - Fork 17
/
utils.py
63 lines (50 loc) · 1.61 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Symplectic ODE-Net | 2019
# Yaofeng Desmond Zhong, Biswadip Dey, Amit Chakraborty
# code structure follows the style of HNN by Sam Greydanus
# https://github.com/greydanus/hamiltonian-nn
import numpy as np
import os, torch, pickle, zipfile
import imageio, shutil
import scipy, scipy.misc, scipy.integrate
# Expose SciPy's IVP solver under a short module-level name for convenience.
from scipy.integrate import solve_ivp
def L2_loss(u, v):
    """Return the mean squared difference between tensors ``u`` and ``v``."""
    residual = u - v
    squared = residual.pow(2)
    return squared.mean()
def abs_loss(u, v):
    """Return the mean absolute difference between tensors ``u`` and ``v``."""
    residual = u - v
    return residual.abs().mean()
def to_pickle(thing, path):
    """Serialize ``thing`` to the file at ``path`` with the highest pickle protocol.

    The file is opened in binary write mode, so any existing file is overwritten.
    """
    with open(path, 'wb') as fh:
        pickle.dump(thing, fh, protocol=pickle.HIGHEST_PROTOCOL)
def from_pickle(path):
    """Load and return the object pickled in the file at ``path``.

    Unpickling can execute arbitrary code, so only call this on files
    you trust (e.g. ones written by ``to_pickle``).
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)
def choose_nonlinearity(name):
    """Return the activation function registered under ``name``.

    Supported names: 'tanh', 'relu', 'sigmoid', 'softplus', 'selu', 'elu',
    and 'swish' (computed as ``x * sigmoid(x)``).

    Raises:
        ValueError: if ``name`` matches none of the supported names.
    """
    def swish(x):
        return x * torch.sigmoid(x)

    functional = torch.nn.functional
    # Ordered (name, callable) pairs. They are scanned with ``==`` so that
    # matching behaves exactly like an if/elif chain of string comparisons.
    registry = (
        ('tanh', torch.tanh),
        ('relu', torch.relu),
        ('sigmoid', torch.sigmoid),
        ('softplus', functional.softplus),
        ('selu', functional.selu),
        ('elu', functional.elu),
        ('swish', swish),
    )
    for key, fn in registry:
        if name == key:
            return fn
    raise ValueError("nonlinearity not recognized")
# from HNN
# for ablation study
def rk4(fun, y0, t, dt, *args, **kwargs):
    """Compute one classical fourth-order Runge-Kutta increment.

    ``fun(y, t, *args, **kwargs)`` must return the time derivative of ``y``.
    Any extra positional/keyword arguments are forwarded to every call.

    Returns the increment ``dy``, not the new state: ``y0 + dy``
    approximates the state at ``t + dt``.
    """
    def deriv(y, tau):
        return fun(y, tau, *args, **kwargs)

    half = dt / 2.0
    t_mid = t + half
    slope1 = deriv(y0, t)
    slope2 = deriv(y0 + half * slope1, t_mid)
    slope3 = deriv(y0 + half * slope2, t_mid)
    slope4 = deriv(y0 + dt * slope3, t + dt)
    weighted = slope1 + 2 * slope2 + 2 * slope3 + slope4
    return dt / 6.0 * weighted