-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_reproducibility.py
277 lines (225 loc) · 11 KB
/
main_reproducibility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import functools
from pathlib import Path
import time
import pickle
import jax
import jax.numpy as jnp
from flax import serialization
from qdax.environments import behavior_descriptor_extractor
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP, MLPDC
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from qdax.tasks.brax_envs import reset_based_scoring_actor_dc_function_brax_envs as scoring_actor_dc_function
import hydra
from omegaconf import OmegaConf, DictConfig
from utils import get_env, get_config, get_repertoire
# Per-environment size of one vmapped batch of evaluations. Requests larger
# than this value are processed as a scan over batches of exactly this size.
# NOTE(review): the values are presumably tuned to fit device memory -- confirm.
ENV_MAX_EVALUATIONS = dict(
    ant_omni=128,
    anttrap_omni=128,
    humanoid_omni=8,
    walker2d_uni=64,
    halfcheetah_uni=64,
    ant_uni=64,
    humanoid_uni=8,
)
def repertoire_reproduciblity(run_dir, num_evaluations):
    """Re-evaluate every genotype of a run's repertoire several times.

    Each genotype stored in the repertoire is scored ``num_evaluations``
    times with different random keys. Evaluations are vmapped in one batch
    when ``num_evaluations`` is below the environment's batch size, and
    otherwise scanned over batches of exactly that size.

    Args:
        run_dir: Directory of a single run, passed to ``get_config`` and
            ``get_repertoire``.
        num_evaluations: Number of evaluations per genotype. Must be at most
            ``ENV_MAX_EVALUATIONS[env]`` or a multiple of it.

    Returns:
        A tuple ``(fitnesses, distances)``, both reshaped to
        ``(num_evaluations, number_of_centroids)``. ``distances`` is the
        Euclidean norm between the descriptor obtained in an evaluation and
        the centroid of the corresponding cell. Entries of cells whose
        repertoire fitness is ``-inf`` are set to NaN in both arrays.

    Raises:
        AssertionError: If ``num_evaluations`` is larger than the batch size
            without being a multiple of it.
        KeyError: If the environment name has no entry in
            ``ENV_MAX_EVALUATIONS``.
    """
    # Run configuration and the batch size associated with its environment
    config = get_config(run_dir)
    batch_size = ENV_MAX_EVALUATIONS[config.env.name]
    assert num_evaluations <= batch_size or num_evaluations % batch_size == 0

    # Repertoire produced by the run
    repertoire = get_repertoire(run_dir)

    # Random key seeded from the run configuration
    random_key = jax.random.PRNGKey(config.seed)

    # Environment
    env = get_env(config)
    reset_fn = jax.jit(env.reset)

    # Policy network
    policy_layer_sizes = config.policy_hidden_layer_sizes + (env.action_size,)
    policy_network = MLP(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )

    # Function playing one step of the policy in the environment
    def play_step_fn(env_state, policy_params, random_key):
        actions = policy_network.apply(policy_params, env_state.obs)
        state_desc = env_state.info["state_descriptor"]
        next_state = env.step(env_state, actions)

        # The descriptor fields of the transition are filled with NaN
        nan_desc = jnp.zeros(env.behavior_descriptor_length) * jnp.nan
        transition = QDTransition(
            obs=env_state.obs,
            next_obs=next_state.obs,
            rewards=next_state.reward,
            dones=next_state.done,
            truncations=next_state.info["truncation"],
            actions=actions,
            state_desc=state_desc,
            next_state_desc=next_state.info["state_descriptor"],
            desc=nan_desc,
            desc_prime=nan_desc,
        )
        return next_state, policy_params, random_key, transition

    # Scoring function
    bd_extraction_fn = behavior_descriptor_extractor[config.env.name]
    scoring_fn = functools.partial(
        scoring_function,
        episode_length=config.env.episode_length,
        play_reset_fn=reset_fn,
        play_step_fn=play_step_fn,
        behavior_descriptor_extractor=bd_extraction_fn,
    )

    # Score the whole repertoire once per key in `keys`
    def evaluate_batch(keys):
        batch_fitnesses, batch_descriptors, _, _ = jax.vmap(
            scoring_fn, in_axes=(None, 0)
        )(repertoire.genotypes, keys)
        batch_distances = jnp.linalg.norm(
            batch_descriptors - repertoire.centroids, axis=-1
        )
        return batch_fitnesses, batch_distances

    # Scan body: nothing is carried, one batch of keys is consumed per step
    def scan_scoring_fn(carry, x):
        (keys,) = x
        return (), evaluate_batch(keys)

    # Evaluate the repertoire
    num_batches = num_evaluations // batch_size
    if num_batches == 0:
        # Fewer evaluations than one full batch: a single vmapped call
        keys = jax.random.split(random_key, num_evaluations)
        fitnesses, distances = evaluate_batch(keys)
    else:
        keys = jax.random.split(random_key, num=num_batches * batch_size)
        keys = jnp.reshape(keys, (num_batches, batch_size, -1))
        _, (fitnesses, distances) = jax.lax.scan(
            scan_scoring_fn,
            (),
            (keys,),
            length=num_batches,
        )

    # Place NaN in fitnesses and distances for cells marked -inf in the repertoire
    empty_cells = jnp.isneginf(repertoire.fitnesses)
    fitnesses = jnp.where(empty_cells, jnp.nan, fitnesses)
    distances = jnp.where(empty_cells, jnp.nan, distances)

    # Flatten the batch axes into a single evaluation axis
    result_shape = (num_evaluations, repertoire.centroids.shape[0])
    return jnp.reshape(fitnesses, result_shape), jnp.reshape(distances, result_shape)
def actor_reproduciblity(run_dir, num_evaluations):
    """Evaluate a run's descriptor-conditioned actor on every centroid.

    The actor parameters stored in ``run_dir / "actor.pickle"`` are rolled
    out ``num_evaluations`` times per centroid of the run's repertoire, with
    the centroid passed as the descriptor the actor is conditioned on.
    Evaluations are vmapped in one batch when ``num_evaluations`` is below
    the environment's batch size, and otherwise scanned over batches of
    exactly that size.

    Args:
        run_dir: Directory of a single run. It is passed to ``get_config``
            and ``get_repertoire`` and must support ``run_dir / "name"``.
        num_evaluations: Number of evaluations per centroid. Must be at most
            ``ENV_MAX_EVALUATIONS[env]`` or a multiple of it.

    Returns:
        A tuple ``(fitnesses, distances)``, both reshaped to
        ``(num_evaluations, number_of_centroids)``. ``distances`` is the
        Euclidean norm between the descriptor obtained in an evaluation and
        the centroid the actor was conditioned on. Entries of cells whose
        repertoire fitness is ``-inf`` are set to NaN in both arrays.

    Raises:
        AssertionError: If ``num_evaluations`` is larger than the batch size
            without being a multiple of it.
        KeyError: If the environment name has no entry in
            ``ENV_MAX_EVALUATIONS``.
    """
    # Get config
    config = get_config(run_dir)
    batch_size = ENV_MAX_EVALUATIONS[config.env.name]
    assert num_evaluations <= batch_size or num_evaluations % batch_size == 0

    # Get repertoire
    repertoire = get_repertoire(run_dir)
    num_centroids = repertoire.centroids.shape[0]

    # Init a random key
    random_key = jax.random.PRNGKey(config.seed)

    # Init environment
    env = get_env(config)
    reset_fn = jax.jit(env.reset)

    # Init policy network
    policy_layer_sizes = config.policy_hidden_layer_sizes + (env.action_size,)
    actor_dc_network = MLPDC(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )

    # Get descriptor-conditioned actor: freshly initialised parameters serve
    # as the template into which the saved state dict is restored.
    random_key, subkey = jax.random.split(random_key)
    fake_obs = jnp.zeros(shape=(env.observation_size,))
    fake_desc = jnp.zeros(shape=(env.behavior_descriptor_length,))
    actor_dc_params = actor_dc_network.init(subkey, fake_obs, fake_desc)
    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # function at run directories you trust.
    with open(run_dir / "actor.pickle", "rb") as params_file:
        state_dict = pickle.load(params_file)
    actor_dc_params = serialization.from_state_dict(actor_dc_params, state_dict)

    # Rescale a descriptor from the environment's descriptor limits to [-1, 1]
    def normalize_desc(desc):
        return 2*(desc - env.behavior_descriptor_limits[0])/(env.behavior_descriptor_limits[1] - env.behavior_descriptor_limits[0]) - 1

    # Define the function to play a step with the actor in the environment
    def play_step_actor_dc_fn(env_state, actor_dc_params, desc, random_key):
        desc_prime_normalized = normalize_desc(desc)
        actions = actor_dc_network.apply(actor_dc_params, env_state.obs, desc_prime_normalized)
        state_desc = env_state.info["state_descriptor"]
        next_state = env.step(env_state, actions)

        # The descriptor fields of the transition are filled with NaN
        nan_desc = jnp.zeros(env.behavior_descriptor_length) * jnp.nan
        transition = QDTransition(
            obs=env_state.obs,
            next_obs=next_state.obs,
            rewards=next_state.reward,
            dones=next_state.done,
            truncations=next_state.info["truncation"],
            actions=actions,
            state_desc=state_desc,
            next_state_desc=next_state.info["state_descriptor"],
            desc=nan_desc,
            desc_prime=nan_desc,
        )
        return next_state, actor_dc_params, desc, random_key, transition

    # Prepare the scoring function
    bd_extraction_fn = behavior_descriptor_extractor[config.env.name]
    scoring_fn = jax.jit(functools.partial(
        scoring_actor_dc_function,
        episode_length=config.env.episode_length,
        play_reset_fn=reset_fn,
        play_step_actor_dc_fn=play_step_actor_dc_fn,
        behavior_descriptor_extractor=bd_extraction_fn,
    ))

    # Score the actor on all centroids once per key in `keys`
    def evaluate_batch(params, keys):
        batch_fitnesses, batch_descriptors, _, _ = jax.vmap(
            scoring_fn, in_axes=(None, None, 0)
        )(params, repertoire.centroids, keys)
        batch_distances = jnp.linalg.norm(
            batch_descriptors - repertoire.centroids, axis=-1
        )
        return batch_fitnesses, batch_distances

    # Scan body: the (unchanged) actor parameters are carried through, one
    # batch of keys is consumed per step
    def scan_scoring_fn(carry, x):
        (random_keys,) = x
        (params,) = carry
        return (params,), evaluate_batch(params, random_keys)

    # Replicate the actor once per centroid. The count is taken from the
    # repertoire itself (not from config.num_centroids) so that the leading
    # axis of the parameters always matches the centroids passed alongside
    # them and the final reshape below.
    actor_dc_params = jax.tree_util.tree_map(
        lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), num_centroids, axis=0),
        actor_dc_params,
    )

    # Evaluate the actor
    length = num_evaluations // batch_size
    if length == 0:
        # Fewer evaluations than one full batch: a single vmapped call
        random_keys = jax.random.split(random_key, num_evaluations)
        fitnesses, distances = evaluate_batch(actor_dc_params, random_keys)
    else:
        random_keys = jax.random.split(random_key, num=length * batch_size)
        random_keys = jnp.reshape(random_keys, (length, batch_size, -1))
        _, (fitnesses, distances) = jax.lax.scan(
            scan_scoring_fn,
            (actor_dc_params,),
            (random_keys,),
            length=length,
        )

    # Place NaN in fitnesses and distances for cells marked -inf in the repertoire
    empty_cells = jnp.isneginf(repertoire.fitnesses)
    fitnesses = jnp.where(empty_cells, jnp.nan, fitnesses)
    distances = jnp.where(empty_cells, jnp.nan, distances)

    # Reshape fitnesses and distances to a single evaluation axis
    fitnesses = jnp.reshape(fitnesses, (num_evaluations, num_centroids))
    distances = jnp.reshape(distances, (num_evaluations, num_centroids))
    return fitnesses, distances
def repertoire_reproduciblity_algo(algo_dir, actor, num_evaluations):
    """Run the reproducibility evaluation on every run of one algorithm.

    Args:
        algo_dir: ``pathlib.Path`` of the algorithm directory; each
            sub-directory is treated as one run.
        actor: If true, evaluate the descriptor-conditioned actor of each run
            (only allowed when ``algo_dir.name`` is ``"dcrl_me"``); otherwise
            evaluate the repertoire of each run.
        num_evaluations: Number of evaluations, forwarded to the per-run
            evaluation function.

    Returns:
        A tuple ``(fitnesses, distances)`` with the per-run results stacked
        along a new leading axis, in sorted run-directory order.

    Raises:
        AssertionError: If ``actor`` is true for a directory other than
            ``dcrl_me``.
        ValueError: If ``algo_dir`` contains no run directory.
    """
    assert algo_dir.name == "dcrl_me" or not actor, "actor only for dcrl_me"

    # iterdir() yields entries in arbitrary order: sort so that the run axis
    # of the stacked results is deterministic, and ignore stray files that
    # are not run directories.
    run_dirs = sorted(entry for entry in algo_dir.iterdir() if entry.is_dir())
    if not run_dirs:
        raise ValueError(f"no run directory found in {algo_dir}")

    fitnesses_list = []
    distances_list = []
    for run_dir in run_dirs:
        start_time = time.time()

        # Evaluate actor or repertoire of this run
        if actor:
            fitnesses, distances = actor_reproduciblity(run_dir, num_evaluations)
        else:
            fitnesses, distances = repertoire_reproduciblity(run_dir, num_evaluations)
        fitnesses_list.append(fitnesses)
        distances_list.append(distances)

        # Display information
        end_time = time.time()
        print(len(fitnesses_list), run_dir, f"time: {end_time - start_time}")

    return jnp.stack(fitnesses_list), jnp.stack(distances_list)
@hydra.main(version_base=None, config_path="configs/", config_name="config_reproducibility")
def main(config: DictConfig) -> None:
    """Compute reproducibility results for one algorithm and pickle them.

    Reads ``algo_name``, ``env_name`` and ``num_evaluations`` from ``config``.
    The special algorithm name ``"dcrl_me_actor"`` selects the runs of
    ``"dcrl_me"`` and evaluates their actor instead of their repertoire.
    The stacked fitnesses and distances are written to
    ``./repertoire_fitnesses.pickle`` and ``./repertoire_distances.pickle``,
    relative to the current working directory.
    """
    results_dir = Path("/src/output/")

    # Resolve the directory holding the runs and whether to evaluate the actor
    actor = config.algo_name == "dcrl_me_actor"
    algo_name = "dcrl_me" if actor else config.algo_name
    algo_dir = results_dir / config.env_name / algo_name

    # Compute reproducibility
    fitnesses, distances = repertoire_reproduciblity_algo(algo_dir, actor, config.num_evaluations)

    # Save results, fitnesses first
    outputs = (
        ("./repertoire_fitnesses.pickle", fitnesses),
        ("./repertoire_distances.pickle", distances),
    )
    for file_name, values in outputs:
        with open(file_name, "wb") as output_file:
            pickle.dump(values, output_file)
# Script entry point: run the application only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()