-
Notifications
You must be signed in to change notification settings - Fork 0
/
visu_brax.py
executable file
·149 lines (124 loc) · 3.92 KB
/
visu_brax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import jax.numpy as jnp
import numpy as np
import jax
import os
from brax.io import html
from brax.io.file import File
from brax.io.json import dumps
from IPython.display import HTML
from qdax.types import RNGKey
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
from qdax.core.containers.mome_repertoire import MOMERepertoire
def save_samples(
    env,
    policy_network,
    random_key: RNGKey,
    repertoire: MapElitesRepertoire,
    num_save_visualisations: int,
    iteration: str = "final",
    save_dir: str = "./",
):
    """Visualise the best individual and some random individuals of a repertoire.

    One HTML file is written per visualised individual via
    ``visualise_individual``.

    Args:
        env: environment forwarded to ``visualise_individual``.
        policy_network: policy forwarded to ``visualise_individual``.
        random_key: JAX PRNG key used to draw the random individuals.
        repertoire: container read through its ``fitnesses`` and
            ``genotypes`` attributes.
        num_save_visualisations: number of random individuals to draw
            (with replacement).
        iteration: label embedded in the output file names.
        save_dir: directory the HTML files are written to.
    """
    fitnesses = repertoire.fitnesses
    number_individuals = len(fitnesses)

    # Visualise the best individual
    best_idx = jnp.argmax(fitnesses)
    params = jax.tree_util.tree_map(
        lambda x: x[best_idx],
        repertoire.genotypes,
    )
    visualise_individual(
        env,
        policy_network,
        params,
        f"best_iteration_{iteration}_individual_{best_idx}.html",
        save_dir,
    )

    # Visualise some random individuals.
    # Only occupied cells are worth rendering. This assumes empty cells carry
    # a fitness of -inf (QDax convention) -- TODO confirm for this repertoire.
    # If no -inf is present every cell counts as occupied and the draw is
    # uniform over all cells.
    occupied = jnp.all(
        jnp.reshape(fitnesses > -jnp.inf, (number_individuals, -1)),
        axis=1,
    )
    num_occupied = int(jnp.sum(occupied))
    if num_occupied > 0:
        random_indices = jax.random.choice(
            random_key,
            number_individuals,
            shape=(num_save_visualisations,),
            p=occupied / num_occupied,
        )
    else:
        # Nothing is marked as occupied: fall back to a uniform draw over
        # all cells rather than failing.
        random_indices = jax.random.randint(
            random_key,
            shape=(num_save_visualisations,),
            minval=0,
            maxval=number_individuals,
        )

    for index in random_indices:
        params = jax.tree_util.tree_map(
            lambda x: x[index],
            repertoire.genotypes,
        )
        visualise_individual(
            env,
            policy_network,
            params,
            f"iteration_{iteration}_individual_{index}.html",
            save_dir,
        )
def save_mo_samples(
    env,
    policy_network,
    random_key: RNGKey,
    repertoire: MOMERepertoire,
    num_save_visualisations: int,
    iteration: str = "final",
    save_dir: str = "./",
):
    """Visualise the best individuals and two sets of sampled individuals.

    One HTML file is written per visualised individual via
    ``visualise_individual``.

    Args:
        env: environment forwarded to ``visualise_individual``.
        policy_network: policy forwarded to ``visualise_individual``.
        random_key: JAX PRNG key; it is split so that the two sampling
            passes below draw independently.
        repertoire: container queried through ``get_best_individuals()``
            and ``sample()``.
        num_save_visualisations: number of individuals drawn in each
            sampling pass.
        iteration: label embedded in the output file names.
        save_dir: directory the HTML files are written to.
    """
    # Use one subkey per sampling pass. Passing the same key to both calls
    # would return the same individuals twice.
    pf_key, population_key = jax.random.split(random_key)

    # Visualise the best individuals
    best_genotypes, best_fitnesses = repertoire.get_best_individuals()
    for index, genotype in enumerate(best_genotypes):
        visualise_individual(
            env,
            policy_network,
            genotype,
            f"best_iteration_{iteration}_fitness_{best_fitnesses[index]}.html",
            save_dir,
        )

    # Visualise individuals intended to come from the global pareto front.
    # NOTE(review): this calls the same `repertoire.sample` as the pass
    # below, so it presumably samples the whole repertoire rather than only
    # the pareto front -- confirm against MOMERepertoire.
    pf_genotypes, _ = repertoire.sample(pf_key, num_save_visualisations)
    for sample in range(num_save_visualisations):
        params = jax.tree_util.tree_map(
            lambda x: x[sample],
            pf_genotypes,
        )
        visualise_individual(
            env,
            policy_network,
            params,
            f"iteration_{iteration}_pf_sample_{sample}.html",
            save_dir,
        )

    # Sample random solutions from entire population
    sampled_genotypes, _ = repertoire.sample(
        population_key, num_save_visualisations
    )
    for sample in range(num_save_visualisations):
        params = jax.tree_util.tree_map(
            lambda x: x[sample],
            sampled_genotypes,
        )
        visualise_individual(
            env,
            policy_network,
            params,
            f"iteration_{iteration}_sample_{sample}.html",
            save_dir,
        )
def visualise_individual(
    env,
    policy_network,
    params,
    name,
    save_dir,
    seed: int = 1,
    max_steps=None,
):
    """Roll out an individual policy and save an HTML visualisation.

    The environment is reset from ``seed`` and stepped with the actions
    returned by ``policy_network.apply(params, obs)`` until the state
    reports ``done`` (or ``max_steps`` states have been recorded). The
    recorded states are rendered with ``brax.io.html.render`` and written
    to ``save_dir/name``.

    Args:
        env: environment providing ``reset``, ``step`` and ``sys``.
        policy_network: object whose ``apply(params, obs)`` returns an action.
        params: parameters of the policy to roll out.
        name: file name of the HTML output.
        save_dir: directory the file is written to.
        seed: seed of the PRNG key used to reset the environment. Defaults
            to 1, the value that was previously hard-coded.
        max_steps: optional cap on the number of recorded states. ``None``
            (default) keeps rolling out until ``state.done``, which never
            ends if the environment never reports done.
    """
    path = os.path.join(save_dir, name)

    jit_env_reset = jax.jit(env.reset)
    jit_env_step = jax.jit(env.step)
    jit_inference_fn = jax.jit(policy_network.apply)

    rollout = []
    rng = jax.random.PRNGKey(seed=seed)
    state = jit_env_reset(rng=rng)

    # The state reached when the loop stops is not recorded.
    while not state.done:
        if max_steps is not None and len(rollout) >= max_steps:
            break
        rollout.append(state)
        action = jit_inference_fn(params, state.obs)
        state = jit_env_step(state, action)

    with File(path, 'w') as fout:
        fout.write(html.render(env.sys, [s.qp for s in rollout], height=480))