Commit
feature(wrh): add edt policy
ruiheng123 committed Jun 25, 2024
1 parent 4ce6583 commit 1acb51d
Showing 11 changed files with 1,165 additions and 0 deletions.
433 changes: 433 additions & 0 deletions ding/policy/edt.py

Large diffs are not rendered by default.
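Since the policy file itself is collapsed, here is a minimal, hypothetical sketch of how a new policy is usually wired into DI-engine's registry so that `policy=dict(type='edt')` in the create-configs below can resolve it; the class name, base class, and model import path are assumptions, not the contents of the real file:

# Hypothetical skeleton only -- the actual ding/policy/edt.py is not rendered in this diff.
from ding.utils import POLICY_REGISTRY
from ding.policy.dt import DTPolicy  # assumption: EDT builds on the Decision Transformer policy


@POLICY_REGISTRY.register('edt')
class EDTPolicy(DTPolicy):
    """Elastic Decision Transformer policy; overrides DT's learn/eval hooks where EDT differs."""

    def default_model(self):
        # assumption: the default model follows ding's (name, import_names) convention
        return 'edt', ['ding.model.template.edt']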

76 changes: 76 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_edt_config.py
@@ -0,0 +1,76 @@
from easydict import EasyDict
from copy import deepcopy

halfcheetah_edt_config = dict(
exp_name='edt_log/d4rl/halfcheetah/halfcheetah_medium_edt_seed0',
env=dict(
env_id='HalfCheetah-v3',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
),
dataset=dict(
env_type='mujoco',
rtg_scale=1000,
context_len=20,
data_dir_prefix='/d4rl/halfcheetah-medium-v2.pkl',
),
policy=dict(
cuda=True,
stop_value=6000,
state_mean=None,
state_std=None,
evaluator_env_num=8,
env_name='HalfCheetah-v3',
rtg_target=6000, # max target return to go
        max_eval_ep_len=1000,  # max length of one episode
wt_decay=1e-4,
warmup_steps=10000,
context_len=20,
weight_decay=0.1,
clip_grad_norm_p=0.25,
model=dict(
state_dim=17,
act_dim=6,
n_blocks=4,
h_dim=512,
context_len=20,
n_heads=4,
drop_p=0.1,
max_timestep=4096,
num_bin=60,
dt_mask=False,
rtg_scale=1000,
num_inputs=3,
real_rtg=False,
continuous=True,
),
learn=dict(batch_size=128),
learning_rate=1e-4,
collect=dict(
data_type='d4rl_trajectory',
unroll_len=1,
),
eval=dict(evaluator=dict(eval_freq=1000, ), ),
),
)

halfcheetah_edt_config = EasyDict(halfcheetah_edt_config)
main_config = halfcheetah_edt_config
halfcheetah_edt_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='edt'),
)
halfcheetah_edt_create_config = EasyDict(halfcheetah_edt_create_config)
create_config = halfcheetah_edt_create_config

if __name__ == "__main__":
from ding.entry import serial_pipeline_edt
config = deepcopy([main_config, create_config])
serial_pipeline_edt(config, seed=0, max_train_iter=1000)
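The config keeps two return-related knobs: `rtg_target` is the raw return-to-go the agent conditions on at evaluation time, and `rtg_scale` is the divisor applied before returns enter the model, following the usual Decision-Transformer convention. A minimal sketch of that scaling (the helper name is illustrative, not taken from edt.py):

# Illustrative return-to-go scaling, assuming the standard DT/EDT convention.
RTG_TARGET = 6000   # raw evaluation-time target return for HalfCheetah
RTG_SCALE = 1000    # divisor applied before returns are fed to the transformer


def scale_rtg(raw_return: float, scale: float = RTG_SCALE) -> float:
    """Convert a raw return-to-go into the normalized range the model sees."""
    return raw_return / scale


print(scale_rtg(RTG_TARGET))  # 6.0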
76 changes: 76 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_expert_edt_config.py
@@ -0,0 +1,76 @@
from easydict import EasyDict
from copy import deepcopy

halfcheetah_edt_config = dict(
exp_name='edt_log/d4rl/halfcheetah/halfcheetah_medium_expert_edt_seed0',
env=dict(
env_id='HalfCheetah-v3',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
),
dataset=dict(
env_type='mujoco',
rtg_scale=1000,
context_len=20,
data_dir_prefix='/d4rl/halfcheetah-medium-expert-v2.pkl',
),
policy=dict(
cuda=True,
stop_value=6000,
state_mean=None,
state_std=None,
evaluator_env_num=8,
env_name='HalfCheetah-v3',
rtg_target=6000, # max target return to go
        max_eval_ep_len=1000,  # max length of one episode
wt_decay=1e-4,
warmup_steps=10000,
context_len=20,
weight_decay=0.1,
clip_grad_norm_p=0.25,
model=dict(
state_dim=17,
act_dim=6,
n_blocks=4,
h_dim=512,
context_len=20,
n_heads=4,
drop_p=0.1,
max_timestep=4096,
num_bin=60,
dt_mask=False,
rtg_scale=1000,
num_inputs=3,
real_rtg=False,
continuous=True,
),
learn=dict(batch_size=128),
learning_rate=1e-4,
collect=dict(
data_type='d4rl_trajectory',
unroll_len=1,
),
eval=dict(evaluator=dict(eval_freq=1000, ), ),
),
)

halfcheetah_edt_config = EasyDict(halfcheetah_edt_config)
main_config = halfcheetah_edt_config
halfcheetah_edt_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='edt'),
)
halfcheetah_edt_create_config = EasyDict(halfcheetah_edt_create_config)
create_config = halfcheetah_edt_create_config

if __name__ == "__main__":
    from ding.entry import serial_pipeline_edt
    config = deepcopy([main_config, create_config])
    serial_pipeline_edt(config, seed=0, max_train_iter=1000)
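`warmup_steps=10000` together with `learning_rate=1e-4` normally means a linear learning-rate warmup, as in the original Decision Transformer training code; the optimizer wiring inside the new policy is not visible in this diff, so the sketch below is only an assumed reconstruction:

# Assumed linear warmup over the first 10k gradient steps; illustrative, not the policy's actual code.
import torch

model = torch.nn.Linear(17, 6)  # stand-in for the EDT transformer (HalfCheetah obs -> action dims)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: min((step + 1) / 10000, 1.0),  # ramp to the full lr by warmup_steps
)

for _ in range(5):           # each training iteration: optimizer step, then scheduler step
    optimizer.step()
    scheduler.step()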
76 changes: 76 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_replay_edt_config.py
@@ -0,0 +1,76 @@
from easydict import EasyDict
from copy import deepcopy

halfcheetah_edt_config = dict(
exp_name='edt_log/d4rl/halfcheetah/halfcheetah_medium_replay_edt_seed0',
env=dict(
env_id='HalfCheetah-v3',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
),
dataset=dict(
env_type='mujoco',
rtg_scale=1000,
context_len=20,
data_dir_prefix='/d4rl/halfcheetah-medium-replay-v2.pkl',
),
policy=dict(
cuda=True,
stop_value=6000,
state_mean=None,
state_std=None,
evaluator_env_num=8,
env_name='HalfCheetah-v3',
rtg_target=6000, # max target return to go
        max_eval_ep_len=1000,  # max length of one episode
wt_decay=1e-4,
warmup_steps=10000,
context_len=20,
weight_decay=0.1,
clip_grad_norm_p=0.25,
model=dict(
state_dim=17,
act_dim=6,
n_blocks=4,
h_dim=512,
context_len=20,
n_heads=4,
drop_p=0.1,
max_timestep=4096,
num_bin=60,
dt_mask=False,
rtg_scale=1000,
num_inputs=3,
real_rtg=False,
continuous=True,
),
learn=dict(batch_size=128),
learning_rate=1e-4,
collect=dict(
data_type='d4rl_trajectory',
unroll_len=1,
),
eval=dict(evaluator=dict(eval_freq=1000, ), ),
),
)

halfcheetah_edt_config = EasyDict(halfcheetah_edt_config)
main_config = halfcheetah_edt_config
halfcheetah_edt_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='edt'),
)
halfcheetah_edt_create_config = EasyDict(halfcheetah_edt_create_config)
create_config = halfcheetah_edt_create_config

if __name__ == "__main__":
    from ding.entry import serial_pipeline_edt
    config = deepcopy([main_config, create_config])
    serial_pipeline_edt(config, seed=0, max_train_iter=1000)
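`num_bin=60` points at EDT's discretized return head, which predicts a distribution over return buckets so the policy can estimate the best achievable return within its context. The exact binning lives in the unrendered `edt.py`; the snippet below assumes uniform buckets over the scaled return range and is purely illustrative:

# Assumed uniform binning of scaled returns into num_bin buckets; not the actual edt.py logic.
def return_to_bin(scaled_rtg: float, num_bin: int = 60, rtg_max: float = 6.0) -> int:
    """Map a scaled return-to-go to a bucket index in [0, num_bin - 1]."""
    idx = int(scaled_rtg / rtg_max * num_bin)
    return min(max(idx, 0), num_bin - 1)


print(return_to_bin(3.55))  # 35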
76 changes: 76 additions & 0 deletions dizoo/d4rl/config/hopper_medium_edt_config.py
@@ -0,0 +1,76 @@
from easydict import EasyDict
from copy import deepcopy

hopper_edt_config = dict(
exp_name='edt_log/d4rl/hopper/hopper_medium_edt_seed0',
env=dict(
env_id='Hopper-v3',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=3600,
),
dataset=dict(
env_type='mujoco',
rtg_scale=1000,
context_len=20,
data_dir_prefix='/d4rl/hopper-medium-v2.pkl',
),
policy=dict(
cuda=True,
stop_value=3600,
state_mean=None,
state_std=None,
evaluator_env_num=8,
env_name='Hopper-v3',
rtg_target=3600, # max target return to go
        max_eval_ep_len=1000,  # max length of one episode
wt_decay=1e-4,
warmup_steps=10000,
context_len=20,
weight_decay=0.1,
clip_grad_norm_p=0.25,
model=dict(
state_dim=11,
act_dim=3,
n_blocks=4,
h_dim=512,
context_len=20,
n_heads=4,
drop_p=0.1,
max_timestep=4096,
num_bin=60,
dt_mask=False,
rtg_scale=1000,
num_inputs=3,
real_rtg=False,
continuous=True,
),
learn=dict(batch_size=128,),
learning_rate=1e-4,
collect=dict(
data_type='d4rl_trajectory',
unroll_len=1,
),
eval=dict(evaluator=dict(eval_freq=1000, ), ),
),
)

hopper_edt_config = EasyDict(hopper_edt_config)
main_config = hopper_edt_config
hopper_edt_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='edt'),
)
hopper_edt_create_config = EasyDict(hopper_edt_create_config)
create_config = hopper_edt_create_config

if __name__ == "__main__":
from ding.entry import serial_pipeline_edt
config = deepcopy([main_config, create_config])
serial_pipeline_edt(config, seed=0, max_train_iter=1000)
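`context_len=20` shows up in both the dataset and model blocks because the offline trajectories are served to the transformer as fixed-length windows. A hedged sketch of that slicing, with zero-padding for short tails (the real D4RL trajectory dataset wrapper is not part of this diff):

# Illustrative fixed-length window sampling; not the actual ding dataset implementation.
import numpy as np


def sample_window(states: np.ndarray, start: int, context_len: int = 20) -> np.ndarray:
    """Slice a context_len window starting at `start`, zero-padding if the episode ends early."""
    window = states[start:start + context_len]
    if len(window) < context_len:
        pad = np.zeros((context_len - len(window),) + states.shape[1:], dtype=states.dtype)
        window = np.concatenate([window, pad], axis=0)
    return window


episode = np.random.randn(1000, 11)              # one Hopper episode: 1000 steps, 11-dim states
print(sample_window(episode, start=995).shape)   # (20, 11)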
76 changes: 76 additions & 0 deletions dizoo/d4rl/config/hopper_medium_expert_edt_config.py
@@ -0,0 +1,76 @@
from easydict import EasyDict
from copy import deepcopy

hopper_edt_config = dict(
exp_name='edt_log/d4rl/hopper/hopper_medium_expert_edt_seed0',
env=dict(
env_id='Hopper-v3',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=3600,
),
dataset=dict(
env_type='mujoco',
rtg_scale=1000,
context_len=20,
data_dir_prefix='/d4rl/hopper-medium-expert-v2.pkl',
),
policy=dict(
cuda=True,
stop_value=3600,
state_mean=None,
state_std=None,
evaluator_env_num=8,
env_name='Hopper-v3',
rtg_target=3600, # max target return to go
        max_eval_ep_len=1000,  # max length of one episode
wt_decay=1e-4,
warmup_steps=10000,
context_len=20,
weight_decay=0.1,
clip_grad_norm_p=0.25,
model=dict(
state_dim=11,
act_dim=3,
n_blocks=4,
h_dim=512,
context_len=20,
n_heads=4,
drop_p=0.1,
max_timestep=4096,
num_bin=60,
dt_mask=False,
rtg_scale=1000,
num_inputs=3,
real_rtg=False,
continuous=True,
),
learn=dict(batch_size=128,),
learning_rate=1e-4,
collect=dict(
data_type='d4rl_trajectory',
unroll_len=1,
),
eval=dict(evaluator=dict(eval_freq=1000, ), ),
),
)

hopper_edt_config = EasyDict(hopper_edt_config)
main_config = hopper_edt_config
hopper_edt_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='edt'),
)
hopper_edt_create_config = EasyDict(hopper_edt_create_config)
create_config = hopper_edt_create_config

if __name__ == "__main__":
from ding.entry import serial_pipeline_edt
config = deepcopy([main_config, create_config])
serial_pipeline_edt(config, seed=0, max_train_iter=1000)
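`state_mean=None` and `state_std=None` suggest that observation statistics are filled in from the offline dataset at load time rather than hard-coded in the config. A minimal sketch of computing and applying such statistics, assuming the standard per-dimension normalization used with D4RL data (the exact hook in the EDT pipeline is an assumption):

# Illustrative per-dimension state normalization; the wiring into the policy is assumed.
import numpy as np


def compute_state_stats(observations: np.ndarray, eps: float = 1e-6):
    """Mean and std over every state in the offline dataset (shape: [num_steps, state_dim])."""
    return observations.mean(axis=0), observations.std(axis=0) + eps


def normalize_state(obs: np.ndarray, state_mean: np.ndarray, state_std: np.ndarray) -> np.ndarray:
    return (obs - state_mean) / state_std


dataset_obs = np.random.randn(10000, 11)             # stacked Hopper observations
state_mean, state_std = compute_state_stats(dataset_obs)
print(normalize_state(dataset_obs[0], state_mean, state_std).shape)  # (11,)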