models.py
from typing import Tuple, Callable, Any

import jax.numpy as jnp
import haiku as hk
import jax
import jax.random as rnd
import numpy as np

Tensor = Any
PRNGKey = Any
Network = Callable[[hk.Params, PRNGKey, Tensor, bool], Tensor]


def get_model(
    dim: int, batch_size: int, num_layers: int, hidden_size: int = 32, do_ev_noise=True,
) -> Tuple[hk.Params, Network]:
    if do_ev_noise:
        noise_dim = 1
    else:
        noise_dim = dim
    l_dim = dim * (dim - 1) // 2
    input_dim = l_dim + noise_dim
    rng_key = rnd.PRNGKey(0)

    def forward_fn(in_data: jnp.ndarray) -> jnp.ndarray:
        # Must have num_heads * key_size (=64) = embedding_size
        x = hk.Linear(hidden_size)(hk.Flatten()(in_data))
        x = jax.nn.gelu(x)
        for _ in range(num_layers - 2):
            x = hk.Linear(hidden_size)(x)
            x = jax.nn.gelu(x)
        x = hk.Linear(hidden_size)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(dim * dim)(x)

    # out_stats = eval_mean(params, Xs, np.zeros(dim))
    forward_fn_init, forward_fn_apply = hk.transform(forward_fn)
    blank_data = np.zeros((batch_size, input_dim))
    laplace_params = forward_fn_init(rng_key, blank_data)
    return laplace_params, forward_fn_apply


def get_model_arrays(
    dim: int,
    batch_size: int,
    num_layers: int,
    rng_key: PRNGKey,
    hidden_size: int = 32,
    do_ev_noise=True,
) -> hk.Params:
    """Only returns parameters so that it can be used in pmap"""
    if do_ev_noise:
        noise_dim = 1
    else:
        noise_dim = dim
    l_dim = dim * (dim - 1) // 2
    input_dim = l_dim + noise_dim

    def forward_fn(in_data: jnp.ndarray) -> jnp.ndarray:
        # Must have num_heads * key_size (=64) = embedding_size
        x = hk.Linear(hidden_size)(hk.Flatten()(in_data))
        x = jax.nn.gelu(x)
        for _ in range(num_layers - 2):
            x = hk.Linear(hidden_size)(x)
            x = jax.nn.gelu(x)
        x = hk.Linear(hidden_size)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(dim * dim)(x)

    # out_stats = eval_mean(params, Xs, np.zeros(dim))
    forward_fn_init, _ = hk.transform(forward_fn)
    blank_data = np.zeros((batch_size, input_dim))
    laplace_params = forward_fn_init(rng_key, blank_data)
    return laplace_params
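

# --- Usage sketch (illustrative only, not part of the upstream file) ---
# A minimal example of how these helpers could be driven, assuming Haiku and
# JAX are installed; the dimensions below are placeholders chosen for the demo.
if __name__ == "__main__":
    dim, batch_size, num_layers = 4, 8, 3

    # get_model returns initial parameters plus the transformed apply function;
    # hk.transform's apply signature is (params, rng, inputs).
    params, apply_fn = get_model(dim, batch_size, num_layers)
    input_dim = dim * (dim - 1) // 2 + 1  # do_ev_noise=True -> noise_dim = 1
    dummy_in = jnp.zeros((batch_size, input_dim))
    out = apply_fn(params, rnd.PRNGKey(1), dummy_in)
    print(out.shape)  # expected: (batch_size, dim * dim)

    # get_model_arrays only returns parameters, so it can be mapped over a
    # batch of PRNG keys (e.g. one per device) before use in pmap-ed code;
    # vmap is used here so the sketch also runs on a single device.
    keys = rnd.split(rnd.PRNGKey(0), jax.local_device_count())
    per_device_params = jax.vmap(
        lambda k: get_model_arrays(dim, batch_size, num_layers, k)
    )(keys)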