diff --git a/ding/policy/edt.py b/ding/policy/edt.py
new file mode 100644
index 0000000000..1e547589c7
--- /dev/null
+++ b/ding/policy/edt.py
@@ -0,0 +1,433 @@
+from typing import List, Dict, Any, Tuple, Optional
+from collections import namedtuple
+import torch.nn.functional as F
+import torch
+import numpy as np
+from ding.torch_utils import to_device
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_decollate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('edt')
+class EDTPolicy(Policy):
+    """
+    Overview:
+        Policy class of Elastic Decision Transformer (EDT).
+        Paper link: https://arxiv.org/abs/2307.02484
+    """
+    config = dict(
+        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+        type='edt',
+        # (bool) Whether to use cuda for network.
+        cuda=False,
+        # (bool) Whether the RL algorithm is on-policy or off-policy.
+        on_policy=False,
+        # (bool) Whether to use priority (priority sampling, IS weight, priority update).
+        priority=False,
+        # (int) Default observation and action shapes.
+        obs_shape=4,
+        action_shape=2,
+        rtg_scale=1000,  # normalize returns to go
+        max_eval_ep_len=1000,  # max len of one episode
+        batch_size=256,  # training batch size
+        wt_decay=1e-4,  # weight decay in optimizer
+        warmup_steps=10000,  # steps for learning rate warmup
+        context_len=20,  # length of transformer input
+        learning_rate=1e-4,
+    )
+
+    def default_model(self) -> Tuple[str, List[str]]:
+        """
+        Overview:
+            Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
+            automatically call this method to get the default model setting and create model.
+        Returns:
+            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+        .. note::
+            The user can define and use a customized network model, but it must obey the same interface definition \
+            indicated by the import_names path. For example, for DQN, the registered name is ``dqn`` and the \
+            import_names is ``ding.model.template.q_learning``.
+        """
+        return 'edt', ['ding.model.template.edt']
+
+    def _init_learn(self) -> None:
+        """
+        Overview:
+            Initialize the learn mode of policy, including related attributes and modules. For Elastic Decision \
+            Transformer, it mainly contains the optimizer, the lr scheduler and algorithm-specific arguments such \
+            as rtg_scale.
+            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+        .. note::
+            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+            and ``_load_state_dict_learn`` methods.
+
+        .. note::
+            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+        .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+        """
+        # rtg_scale: scale of `return to go`
+        # rtg_target: max target of `return to go`
+        # Our goal is to normalize `return to go` to (0, 1), which favours convergence.
+        # As a result, we usually set rtg_scale == rtg_target.
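+        # For example (hypothetical numbers, only to illustrate the scaling): with
+        # rtg_target == rtg_scale == 1000, evaluation starts from a conditioning value of
+        # 1000 / 1000 = 1.0 and every observed reward r decreases it by r / 1000, so the
+        # return-to-go token stays roughly within (0, 1) over the course of an episode.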
+        self.rtg_scale = self._cfg.rtg_scale  # normalize returns to go
+        self.rtg_target = self._cfg.rtg_target  # max target reward_to_go
+        self.max_eval_ep_len = self._cfg.max_eval_ep_len  # max len of one episode
+
+        lr = self._cfg.learning_rate  # learning rate
+        wt_decay = self._cfg.wt_decay  # weight decay
+        warmup_steps = self._cfg.warmup_steps  # warmup steps for lr scheduler
+
+        self.clip_grad_norm_p = self._cfg.clip_grad_norm_p
+        self.context_len = self._cfg.model.context_len  # K in decision transformer
+
+        self.state_dim = self._cfg.model.state_dim
+        self.act_dim = self._cfg.model.act_dim
+
+        self._learn_model = self._model
+        self._atari_env = 'state_mean' not in self._cfg
+        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
+
+        if self._atari_env:
+            self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr)
+        else:
+            self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay)
+
+        self._scheduler = torch.optim.lr_scheduler.LambdaLR(
+            self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)
+        )
+
+        self.max_env_score = -1.0
+
+    def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
+        """
+        Overview:
+            Policy forward function of learn mode (training policy and updating parameters). Forward means \
+            that the policy inputs some training batch data from the offline dataset and then returns the output \
+            result, including various training information such as loss and current learning rate.
+        Arguments:
+            - data (:obj:`List[torch.Tensor]`): The input data used for policy forward, including a series of \
+                processed torch.Tensor data, i.e., timesteps, states, actions, returns_to_go, traj_mask.
+        Returns:
+            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which \
+                will be recorded in text log and tensorboard; values must be python scalars or lists of scalars. \
+                For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+        .. note::
+            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
+            them. For the data types that are not supported, the main reason is that the corresponding model does \
+            not support them. You can implement your own model rather than use the default model. For more \
+            information, please raise an issue in the GitHub repo and we will continue to follow up.
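+
+        .. note::
+            As an illustration (shapes follow the D4RL trajectory datasets used by the configs in this PR and are \
+            an assumption rather than a hard requirement): with batch size ``B``, context length ``K``, state \
+            dimension ``state_dim`` and action dimension ``act_dim``, the expected shapes are ``timesteps (B, K)``, \
+            ``states (B, K, state_dim)``, ``actions (B, K, act_dim)``, ``returns_to_go (B, K)`` or ``(B, K, 1)``, \
+            and ``traj_mask (B, K)``.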
+
+        """
+        self._learn_model.train()
+
+        timesteps, states, actions, returns_to_go, traj_mask = data
+
+        # The shape of `returns_to_go` may differ across datasets (B x T or B x T x 1),
+        # and we need a 3-dim tensor.
+        if len(returns_to_go.shape) == 2:
+            returns_to_go = returns_to_go.unsqueeze(-1)
+
+        if self._basic_discrete_env:
+            actions = actions.to(torch.long)
+            actions = actions.squeeze(-1)
+        action_target = torch.clone(actions).detach().to(self._device)
+
+        if self._atari_env:
+            state_preds, action_preds, return_preds, return_preds2, reward_preds = self._learn_model.forward(
+                timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1
+            )
+        else:
+            state_preds, action_preds, return_preds, return_preds2, reward_preds = self._learn_model.forward(
+                timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
+            )
+
+        if self._atari_env:
+            action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1))
+        else:
+            traj_mask = traj_mask.view(-1, )
+
+            # only consider non-padded elements
+            action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0]
+
+            if self._cfg.model.continuous:
+                action_target = action_target.view(-1, self.act_dim)[traj_mask > 0]
+                action_loss = F.mse_loss(action_preds, action_target)
+            else:
+                action_target = action_target.view(-1)[traj_mask > 0]
+                action_loss = F.cross_entropy(action_preds, action_target)
+
+        self._optimizer.zero_grad()
+        action_loss.backward()
+        if self._cfg.multi_gpu:
+            self.sync_gradients(self._learn_model)
+        torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p)
+        self._optimizer.step()
+        self._scheduler.step()
+
+        return {
+            'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'],
+            'action_loss': action_loss.detach().cpu().item(),
+            'total_loss': action_loss.detach().cpu().item(),
+        }
+
+    def _init_eval(self) -> None:
+        """
+        Overview:
+            Initialize the eval mode of policy, including related attributes and modules. For EDT, it contains the \
+            eval model and some algorithm-specific parameters such as context_len, max_eval_ep_len, etc.
+            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+        .. tip::
+            For the evaluation of complete episodes, we need to maintain some historical information for transformer \
+            inference. These variables need to be initialized in ``_init_eval`` and reset in ``_reset_eval`` when \
+            necessary.
+
+        .. note::
+            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
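+
+        .. note::
+            Concretely, the per-environment history kept between evaluation steps (all defined below) consists of \
+            ``self.states``, ``self.actions``, ``self.rewards_to_go``, ``self.timesteps``, ``self.running_rtg`` \
+            and the step counters ``self.t``.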
+ """ + self._eval_model = self._model + # init data + self._device = torch.device(self._device) + self.rtg_scale = self._cfg.rtg_scale # normalize returns to go + self.rtg_target = self._cfg.rtg_target # max target reward_to_go + self.state_dim = self._cfg.model.state_dim + self.act_dim = self._cfg.model.act_dim + self.eval_batch_size = self._cfg.evaluator_env_num + self.max_eval_ep_len = self._cfg.max_eval_ep_len + self.context_len = self._cfg.model.context_len # K in decision transformer + + self.t = [0 for _ in range(self.eval_batch_size)] + if self._cfg.model.continuous: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device + ) + else: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device + ) + self._atari_env = 'state_mean' not in self._cfg + self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg + if self._atari_env: + self.states = torch.zeros( + ( + self.eval_batch_size, + self.max_eval_ep_len, + ) + tuple(self.state_dim), + dtype=torch.float32, + device=self._device + ) + self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] + else: + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] + self.states = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device + ) + self.state_mean = torch.from_numpy(np.array(self._cfg.state_mean)).to(self._device) + self.state_std = torch.from_numpy(np.array(self._cfg.state_std)).to(self._device) + self.timesteps = torch.arange( + start=0, end=self.max_eval_ep_len, step=1 + ).repeat(self.eval_batch_size, 1).to(self._device) + self.rewards_to_go = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device + ) + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance, such as interacting with envs. \ + Forward means that the policy gets some input data (current obs/return-to-go and historical information) \ + from the envs and then returns the output data, such as the action to interact with the envs. \ + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs and \ + reward to calculate running return-to-go. The key of the dict is environment id and the value is the \ + corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + Decision Transformer will do different operations for different types of envs in evaluation. 
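+
+        .. note::
+            A minimal sketch of the interface (env ids and tensors here are illustrative): the input looks like \
+            ``{0: {'obs': obs_0, 'reward': rew_0}, 1: {'obs': obs_1, 'reward': rew_1}}`` and the returned dict \
+            looks like ``{0: {'action': act_0}, 1: {'action': act_1}}``.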
+ """ + # save and forward + data_id = list(data.keys()) + + self._eval_model.eval() + with torch.no_grad(): + if self._atari_env: + states = torch.zeros( + ( + self.eval_batch_size, + self.context_len, + ) + tuple(self.state_dim), + dtype=torch.float32, + device=self._device + ) + timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, device=self._device) + else: + states = torch.zeros( + (self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device + ) + timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device) + if not self._cfg.model.continuous: + actions = torch.zeros( + (self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device + ) + else: + actions = torch.zeros( + (self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self._device + ) + rewards_to_go = torch.zeros( + (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device + ) + for i in data_id: + if self._atari_env: + self.states[i, self.t[i]] = data[i]['obs'].to(self._device) + else: + self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std + self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale).to(self._device) + self.rewards_to_go[i, self.t[i]] = self.running_rtg[i] + + if self.t[i] <= self.context_len: + if self._atari_env: + timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones( + (1, 1), dtype=torch.int64 + ).to(self._device) + else: + timesteps[i] = self.timesteps[i, :self.context_len] + states[i] = self.states[i, :self.context_len] + actions[i] = self.actions[i, :self.context_len] + rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] + else: + if self._atari_env: + timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones( + (1, 1), dtype=torch.int64 + ).to(self._device) + else: + timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + if self._basic_discrete_env: + actions = actions.squeeze(-1) + _, act_preds, _, _, _= self._eval_model.forward(timesteps, states, actions, rewards_to_go) + del timesteps, states, actions, rewards_to_go + + logits = act_preds[:, -1, :] + if not self._cfg.model.continuous: + if self._atari_env: + probs = F.softmax(logits, dim=-1) + act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device) + for i in data_id: + act[i] = torch.multinomial(probs[i], num_samples=1) + else: + act = torch.argmax(logits, axis=1).unsqueeze(1) + else: + act = logits + for i in data_id: + self.actions[i, self.t[i]] = act[i] # TODO: self.actions[i] should be a queue when exceed max_t + self.t[i] += 1 + + if self._cuda: + act = to_device(act, 'cpu') + output = {'action': act} + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset some stateful variables for eval mode when necessary, such as the historical info of transformer \ + for decision transformer. If ``data_id`` is None, it means to reset all the stateful \ + varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. 
For example, \ + different environments/episodes in evaluation in ``data_id`` will have different history. + Arguments: + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + specified by ``data_id``. + """ + # clean data + if data_id is None: + self.t = [0 for _ in range(self.eval_batch_size)] + self.timesteps = torch.arange( + start=0, end=self.max_eval_ep_len, step=1 + ).repeat(self.eval_batch_size, 1).to(self._device) + if not self._cfg.model.continuous: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device + ) + else: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), + dtype=torch.float32, + device=self._device + ) + if self._atari_env: + self.states = torch.zeros( + ( + self.eval_batch_size, + self.max_eval_ep_len, + ) + tuple(self.state_dim), + dtype=torch.float32, + device=self._device + ) + self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] + else: + self.states = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), + dtype=torch.float32, + device=self._device + ) + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] + + self.rewards_to_go = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device + ) + else: + for i in data_id: + self.t[i] = 0 + if not self._cfg.model.continuous: + self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self._device) + else: + self.actions[i] = torch.zeros( + (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device + ) + if self._atari_env: + self.states[i] = torch.zeros( + (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device + ) + self.running_rtg[i] = self.rtg_target + else: + self.states[i] = torch.zeros( + (self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device + ) + self.running_rtg[i] = self.rtg_target / self.rtg_scale + self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device) + self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. 
+ """ + return ['cur_lr', 'action_loss'] + + def _init_collect(self) -> None: + pass + + def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: + pass + + def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + pass + + def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]: + pass diff --git a/dizoo/d4rl/config/halfcheetah_medium_edt_config.py b/dizoo/d4rl/config/halfcheetah_medium_edt_config.py new file mode 100644 index 0000000000..dd48e15ecd --- /dev/null +++ b/dizoo/d4rl/config/halfcheetah_medium_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +halfcheetah_edt_config = dict( + exp_name='edt_log/d4rl/halfcheetah/halfcheetah_medium_edt_seed0', + env=dict( + env_id='HalfCheetah-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/halfcheetah-medium-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='HalfCheetah-v3', + rtg_target=6000, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +halfcheetah_edt_config = EasyDict(halfcheetah_edt_config) +main_config = halfcheetah_edt_config +halfcheetah_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +halfcheetah_edt_create_config = EasyDict(halfcheetah_edt_create_config) +create_config = halfcheetah_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_edt_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_edt_config.py new file mode 100644 index 0000000000..80904d1d7d --- /dev/null +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +halfcheetah_dt_config = dict( + exp_name='edt_log/d4rl/halfcheetah/halfcheetah_medium_expert_edt_seed0', + env=dict( + env_id='HalfCheetah-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/halfcheetah-medium-expert-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='HalfCheetah-v3', + rtg_target=6000, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + 
act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +halfcheetah_dt_config = EasyDict(halfcheetah_dt_config) +main_config = halfcheetah_dt_config +halfcheetah_dt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +halfcheetah_dt_create_config = EasyDict(halfcheetah_dt_create_config) +create_config = halfcheetah_dt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_dt + config = deepcopy([main_config, create_config]) + serial_pipeline_dt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/d4rl/config/halfcheetah_medium_replay_edt_config.py b/dizoo/d4rl/config/halfcheetah_medium_replay_edt_config.py new file mode 100644 index 0000000000..eab048f855 --- /dev/null +++ b/dizoo/d4rl/config/halfcheetah_medium_replay_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +halfcheetah_dt_config = dict( + exp_name='edt_log/d4rl/halfcheetah/halfcheetah_medium_replay_edt_seed0', + env=dict( + env_id='HalfCheetah-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/halfcheetah-medium-replay-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=6000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='HalfCheetah-v3', + rtg_target=6000, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +halfcheetah_dt_config = EasyDict(halfcheetah_dt_config) +main_config = halfcheetah_dt_config +halfcheetah_dt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +halfcheetah_dt_create_config = EasyDict(halfcheetah_dt_create_config) +create_config = halfcheetah_dt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_dt + config = deepcopy([main_config, create_config]) + serial_pipeline_dt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/d4rl/config/hopper_medium_edt_config.py b/dizoo/d4rl/config/hopper_medium_edt_config.py new file mode 100644 index 0000000000..20d5fbed1f --- /dev/null +++ b/dizoo/d4rl/config/hopper_medium_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +hopper_edt_config = dict( + exp_name='edt_log/d4rl/hopper/hopper_medium_edt_seed0', + env=dict( + env_id='Hopper-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + 
n_evaluator_episode=8, + stop_value=3600, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/hopper-medium-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=3600, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Hopper-v3', + rtg_target=3600, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=11, + act_dim=3, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128,), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +hopper_edt_config = EasyDict(hopper_edt_config) +main_config = hopper_edt_config +hopper_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +hopper_edt_create_config = EasyDict(hopper_edt_create_config) +create_config = hopper_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) \ No newline at end of file diff --git a/dizoo/d4rl/config/hopper_medium_expert_edt_config.py b/dizoo/d4rl/config/hopper_medium_expert_edt_config.py new file mode 100644 index 0000000000..3410cbab7d --- /dev/null +++ b/dizoo/d4rl/config/hopper_medium_expert_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +hopper_edt_config = dict( + exp_name='edt_log/d4rl/hopper/hopper_medium_expert_edt_seed0', + env=dict( + env_id='Hopper-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=3600, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/hopper-medium-expert-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=3600, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Hopper-v3', + rtg_target=3600, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=11, + act_dim=3, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128,), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +hopper_edt_config = EasyDict(hopper_edt_config) +main_config = hopper_edt_config +hopper_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +hopper_edt_create_config = EasyDict(hopper_edt_create_config) +create_config = hopper_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) \ No newline at end of file diff 
--git a/dizoo/d4rl/config/hopper_medium_replay_edt_config.py b/dizoo/d4rl/config/hopper_medium_replay_edt_config.py new file mode 100644 index 0000000000..b316fb5694 --- /dev/null +++ b/dizoo/d4rl/config/hopper_medium_replay_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +hopper_edt_config = dict( + exp_name='edt_log/d4rl/hopper/hopper_medium_replay_edt_seed0', + env=dict( + env_id='Hopper-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=3600, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/hopper-medium-replay-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=3600, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Hopper-v3', + rtg_target=3600, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=11, + act_dim=3, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128,), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +hopper_edt_config = EasyDict(hopper_edt_config) +main_config = hopper_edt_config +hopper_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +hopper_edt_create_config = EasyDict(hopper_edt_create_config) +create_config = hopper_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) \ No newline at end of file diff --git a/dizoo/d4rl/config/walker2d_medium_edt_config.py b/dizoo/d4rl/config/walker2d_medium_edt_config.py new file mode 100644 index 0000000000..a695da8922 --- /dev/null +++ b/dizoo/d4rl/config/walker2d_medium_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +walk2d_edt_config = dict( + exp_name='edt_log/d4rl/walker2d/walker2d_medium_edt_seed0', + env=dict( + env_id='Walker2d-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=5000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/walker2d-medium-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=5000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Walker2d-v3', + rtg_target=5000, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +walk2d_edt_config = EasyDict(walk2d_edt_config) +main_config = walk2d_edt_config 
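+# Note: `state_mean` and `state_std` are left as None above; they are expected to be filled in at runtime from the
+# offline dataset statistics (see `dataset.get_state_stats()` in dizoo/d4rl/entry/d4rl_edt_mujoco.py).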
+walk2d_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +walk2d_edt_create_config = EasyDict(walk2d_edt_create_config) +create_config = walk2d_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/d4rl/config/walker2d_medium_expert_edt_config.py b/dizoo/d4rl/config/walker2d_medium_expert_edt_config.py new file mode 100644 index 0000000000..b00f59a54d --- /dev/null +++ b/dizoo/d4rl/config/walker2d_medium_expert_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +walk2d_edt_config = dict( + exp_name='edt_log/d4rl/walker2d/walker2d_medium_expert_edt_seed0', + env=dict( + env_id='Walker2d-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=5000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/walker2d-medium-expert-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=5000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Walker2d-v3', + rtg_target=5000, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=4, + h_dim=512, + context_len=20, + n_heads=4, + drop_p=0.1, + max_timestep=4096, + num_bin=60, + dt_mask=False, + rtg_scale=1000, + num_inputs=3, + real_rtg=False, + continuous=True, + ), + learn=dict(batch_size=128), + learning_rate=1e-4, + collect=dict( + data_type='d4rl_trajectory', + unroll_len=1, + ), + eval=dict(evaluator=dict(eval_freq=1000, ), ), + ), +) + +walk2d_edt_config = EasyDict(walk2d_edt_config) +main_config = walk2d_edt_config +walk2d_edt_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='edt'), +) +walk2d_edt_create_config = EasyDict(walk2d_edt_create_config) +create_config = walk2d_edt_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_edt + config = deepcopy([main_config, create_config]) + serial_pipeline_edt(config, seed=0, max_train_iter=1000) diff --git a/dizoo/d4rl/config/walker2d_medium_replay_edt_config.py b/dizoo/d4rl/config/walker2d_medium_replay_edt_config.py new file mode 100644 index 0000000000..95afd4b5b2 --- /dev/null +++ b/dizoo/d4rl/config/walker2d_medium_replay_edt_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict +from copy import deepcopy + +walk2d_edt_config = dict( + exp_name='edt_log/d4rl/walker2d/walker2d_medium_replay_edt_seed0', + env=dict( + env_id='Walker2d-v3', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=5000, + ), + dataset=dict( + env_type='mujoco', + rtg_scale=1000, + context_len=20, + data_dir_prefix='/d4rl/walker2d-medium-replay-v2.pkl', + ), + policy=dict( + cuda=True, + stop_value=5000, + state_mean=None, + state_std=None, + evaluator_env_num=8, + env_name='Walker2d-v3', + rtg_target=5000, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( 
+            state_dim=17,
+            act_dim=6,
+            n_blocks=4,
+            h_dim=512,
+            context_len=20,
+            n_heads=4,
+            drop_p=0.1,
+            max_timestep=4096,
+            num_bin=60,
+            dt_mask=False,
+            rtg_scale=1000,
+            num_inputs=3,
+            real_rtg=False,
+            continuous=True,
+        ),
+        learn=dict(batch_size=128),
+        learning_rate=1e-4,
+        collect=dict(
+            data_type='d4rl_trajectory',
+            unroll_len=1,
+        ),
+        eval=dict(evaluator=dict(eval_freq=1000, ), ),
+    ),
+)
+
+walk2d_edt_config = EasyDict(walk2d_edt_config)
+main_config = walk2d_edt_config
+walk2d_edt_create_config = dict(
+    env=dict(
+        type='mujoco',
+        import_names=['dizoo.mujoco.envs.mujoco_env'],
+    ),
+    env_manager=dict(type='subprocess'),
+    policy=dict(type='edt'),
+)
+walk2d_edt_create_config = EasyDict(walk2d_edt_create_config)
+create_config = walk2d_edt_create_config
+
+if __name__ == "__main__":
+    from ding.entry import serial_pipeline_edt
+    config = deepcopy([main_config, create_config])
+    serial_pipeline_edt(config, seed=0, max_train_iter=1000)
diff --git a/dizoo/d4rl/entry/d4rl_edt_mujoco.py b/dizoo/d4rl/entry/d4rl_edt_mujoco.py
new file mode 100644
index 0000000000..d56a9dd9ae
--- /dev/null
+++ b/dizoo/d4rl/entry/d4rl_edt_mujoco.py
@@ -0,0 +1,48 @@
+import gym
+import torch
+import numpy as np
+from ditk import logging
+from ding.model.template.elastic_decision_transformer import ElasticDecisionTransformer
+from ding.policy.edt import EDTPolicy
+from ding.envs import BaseEnvManagerV2
+from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
+from ding.data import create_dataset
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OfflineRLContext
+from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher_from_mem, offline_logger, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.d4rl.envs import D4RLEnv
+from dizoo.d4rl.config.hopper_medium_edt_config import main_config, create_config
+
+
+def main():
+    # If you don't have offline data, you need to prepare it first and set the data path (`data_dir_prefix`) in the config.
+    # For demonstration, we can also train an RL policy (e.g. SAC) and collect some data.
+    logging.getLogger().setLevel(logging.INFO)
+    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+    ding_init(cfg)
+    with task.start(async_mode=False, ctx=OfflineRLContext()):
+        evaluator_env = BaseEnvManagerV2(
+            env_fn=[lambda: AllinObsWrapper(D4RLEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)],
+            cfg=cfg.env.manager
+        )
+
+        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+        dataset = create_dataset(cfg)
+        # env_data_stats = dataset.get_d4rl_dataset_stats(cfg.policy.dataset_name)
+        cfg.policy.state_mean, cfg.policy.state_std = dataset.get_state_stats()
+        model = ElasticDecisionTransformer(**cfg.policy.model)
+        policy = EDTPolicy(cfg.policy, model=model)
+        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+        task.use(offline_data_fetcher_from_mem(cfg, dataset))
+        task.use(trainer(cfg, policy.learn_mode))
+        task.use(termination_checker(max_train_iter=5e4))
+        task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+        task.use(offline_logger())
+        task.run()
+
+
+if __name__ == "__main__":
+    main()
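For reference, each config file in this patch can also be launched on its own through serial_pipeline_edt, mirroring the __main__ blocks above; a minimal sketch (the walker2d-medium config is chosen arbitrarily here):

    from copy import deepcopy

    from ding.entry import serial_pipeline_edt
    from dizoo.d4rl.config.walker2d_medium_edt_config import main_config, create_config

    if __name__ == "__main__":
        # Same call pattern as the config files' __main__ blocks in this diff.
        serial_pipeline_edt(deepcopy([main_config, create_config]), seed=0, max_train_iter=1000)

Alternatively, the new entry script can be run directly (python dizoo/d4rl/entry/d4rl_edt_mujoco.py); it defaults to the hopper-medium config imported at its top.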