# set_up_brax.py
from typing import Any, Callable, Tuple
import jax.numpy as jnp
import qdax
from qdax import environments
from qdax.tasks.brax_envs import create_brax_scoring_fn
from qdax.types import Genotype, RNGKey


def get_environment_brax(
    env_name: str,
    episode_length: int,
    fixed_init_state: bool,
) -> Any:
    # Initialising environment
    if env_name == "anttrap":
        env = environments.create(
            env_name,
            episode_length=episode_length,
            fixed_init_state=fixed_init_state,
            use_contact_forces=False,
            exclude_current_positions_from_observation=False,
        )
    elif env_name == "ant_uni":
        env = environments.create(
            env_name,
            episode_length=episode_length,
            fixed_init_state=fixed_init_state,
            use_contact_forces=False,
        )
    else:
        env = environments.create(
            env_name,
            episode_length=episode_length,
            fixed_init_state=fixed_init_state,
        )
    return env
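# A minimal usage sketch (the environment name, episode length and the extra
# `import jax` below are illustrative assumptions, not values fixed by this
# module):
#
#     env = get_environment_brax("ant_uni", episode_length=250, fixed_init_state=True)
#     state = env.reset(jax.random.PRNGKey(0))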


def get_policy_struc_brax(
    env: Any,
    policy_hidden_layer_sizes: Tuple,
) -> Tuple[int, int, Tuple, Callable]:
    """Return the action size, observation size, full policy layer sizes
    and final activation (tanh) for a policy acting in `env`."""
    policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
    return env.action_size, env.observation_size, policy_layer_sizes, jnp.tanh
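# Sketch of turning the returned structure into a policy network, assuming
# QDax's MLP (qdax.core.neuroevolution.networks.networks.MLP) accepts
# `layer_sizes` and `final_activation` keyword arguments:
#
#     from qdax.core.neuroevolution.networks.networks import MLP
#
#     _, _, layer_sizes, final_activation = get_policy_struc_brax(env, (64, 64))
#     policy_network = MLP(layer_sizes=layer_sizes, final_activation=final_activation)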


def get_reward_offset_brax(env: Any, env_name: str) -> jnp.ndarray:
    return environments.reward_offset[env_name]


def get_behavior_descriptor_length_brax(env: Any, env_name: str) -> jnp.ndarray:
    return env.behavior_descriptor_length
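# Illustrative lookups (the task name "ant_uni" is an assumption; in QDax the
# reward offset is typically added to per-step rewards so that fitnesses stay
# non-negative):
#
#     min_reward_offset = get_reward_offset_brax(env, "ant_uni")
#     bd_length = get_behavior_descriptor_length_brax(env, "ant_uni")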


def get_scoring_function_brax(
    env: Any,
    env_name: str,
    episode_length: int,
    policy_network: Genotype,
    random_key: RNGKey,
) -> Tuple[Callable, RNGKey]:
    """Build the Brax scoring function for `env` using its behaviour descriptor extractor."""
    bd_extraction_fn = qdax.environments.behavior_descriptor_extractor[env_name]
    scoring_fn, random_key = create_brax_scoring_fn(
        env,
        policy_network,
        bd_extraction_fn,
        random_key,
        episode_length=episode_length,
    )
    return scoring_fn, random_key
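# A minimal sketch of how the returned scoring function is typically used in
# QDax (the batch of genotypes and the surrounding setup are assumptions, not
# provided by this module):
#
#     scoring_fn, random_key = get_scoring_function_brax(
#         env, "ant_uni", episode_length=250,
#         policy_network=policy_network, random_key=random_key,
#     )
#     fitnesses, descriptors, extra_scores, random_key = scoring_fn(genotypes, random_key)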