From 485264b3b174ada066f330373400afdd1b80af70 Mon Sep 17 00:00:00 2001
From: wrh12345
Date: Mon, 1 Jul 2024 20:07:00 +0800
Subject: [PATCH] add edt policy

---
 ding/entry/tests/test_serial_entry.py          |   2 +
 ding/policy/edt.py                             | 455 +++++++++++++-----
 ding/utils/data/dataset.py                     | 284 +++++++++++
 .../config/halfcheetah_medium_edt_config.py    |  18 +-
 dizoo/d4rl/config/hopper_medium_edt_config.py  |  46 +-
 .../d4rl/config/walker2d_medium_edt_config.py  |  20 +-
 6 files changed, 680 insertions(+), 145 deletions(-)

diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py
index d36f6bc717..c6e537883d 100644
--- a/ding/entry/tests/test_serial_entry.py
+++ b/ding/entry/tests/test_serial_entry.py
@@ -651,7 +651,9 @@ def test_discrete_dt():
     from ding.data import create_dataset
     from ding.config import compile_config
     from ding.model import DecisionTransformer
+    from ding.model.template.elastic_decision_transformer import ElasticDecisionTransformer
     from ding.policy import DTPolicy
+    from ding.policy.edt import EDTPolicy
     from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
         OfflineMemoryDataFetcher, offline_logger, termination_checker
     ding_init(config[0])
diff --git a/ding/policy/edt.py b/ding/policy/edt.py
index 1e547589c7..cebe1d65c5 100644
--- a/ding/policy/edt.py
+++ b/ding/policy/edt.py
@@ -1,4 +1,5 @@
 from typing import List, Dict, Any, Tuple, Optional
+import math
 from collections import namedtuple
 import torch.nn.functional as F
 import torch
@@ -8,13 +9,28 @@
 from ding.utils.data import default_decollate
 from .base_policy import Policy
+
+REF_MIN_SCORE = {
+    'halfcheetah': -280.178953,
+    'walker2d': 1.629008,
+    'hopper': -20.272305,
+    'ant': -325.6,
+    'antmaze': 0.0,
+}
+
+REF_MAX_SCORE = {
+    'halfcheetah': 12135.0,
+    'walker2d': 4592.3,
+    'hopper': 3234.3,
+    'ant': 3879.7,
+    'antmaze': 700,
+}
 
 @POLICY_REGISTRY.register('edt')
 class EDTPolicy(Policy):
     """
     Overview:
-        This is the implementation of Elastic Decision Transformer.
-        Paper link: https://arxiv.org/abs/2307.02484
+        Policy class of the Elastic Decision Transformer (EDT) algorithm, which selects the history \
+        length dynamically at inference time.
+        Paper link: https://arxiv.org/abs/2307.02484.
     """
     config = dict(
         # (str) RL policy register name (refer to function "POLICY_REGISTRY").
@@ -30,7 +46,7 @@ class EDTPolicy(Policy):
         action_shape=2,
         rtg_scale=1000,  # normalize returns to go
         max_eval_ep_len=1000,  # max len of one episode
-        batch_size=256,  # training batch size
+        batch_size=64,  # training batch size
         wt_decay=1e-4,  # decay weight in optimizer
         warmup_steps=10000,  # steps for learning rate warmup
         context_len=20,  # length of transformer input
@@ -74,19 +90,31 @@ def _init_learn(self) -> None:
         # rtg_target: max target of `return to go`
         # Our goal is normalize `return to go` to (0, 1), which will favour the covergence.
         # As a result, we usually set rtg_scale == rtg_target.
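+        # For example, with rtg_target == rtg_scale == 1000, a Hopper target return of 3600 starts
+        # evaluation with a normalized return-to-go of 3.6, which shrinks towards 0 as rewards arrive.
+        # EDT additionally discretizes returns into `num_bin` bins between the D4RL reference scores
+        # (REF_MIN_SCORE / REF_MAX_SCORE above), so the raw `env_id` is read from the config below.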
+ self.env_name = self._cfg.env_id + self.rtg_scale = self._cfg.rtg_scale # normalize returns to go self.rtg_target = self._cfg.rtg_target # max target reward_to_go self.max_eval_ep_len = self._cfg.max_eval_ep_len # max len of one episode - + + self.expectile = self._cfg.weights.expectile + self.top_percentile = self._cfg.weights.top_percentile + self.expert_weight = self._cfg.weights.expert_weight + self.exp_loss_weight = self._cfg.weights.exp_loss_weight + self.state_loss_weight = self._cfg.weights.state_loss_weight + self.cross_entropy_weight = self._cfg.weights.cross_entropy_weight + + + lr = self._cfg.learning_rate # learning rate wt_decay = self._cfg.wt_decay # weight decay warmup_steps = self._cfg.warmup_steps # warmup steps for lr scheduler self.clip_grad_norm_p = self._cfg.clip_grad_norm_p + self.context_len = self._cfg.model.context_len # K in decision transformer - self.state_dim = self._cfg.model.state_dim self.act_dim = self._cfg.model.act_dim + self.num_bin = self._cfg.model.num_bin # num of bin self._learn_model = self._model self._atari_env = 'state_mean' not in self._cfg @@ -126,27 +154,61 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: """ self._learn_model.train() - timesteps, states, actions, returns_to_go, traj_mask = data + timesteps, states, next_states, actions, returns_to_go, rewards, traj_mask = data # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), # and we need a 3-dim tensor if len(returns_to_go.shape) == 2: returns_to_go = returns_to_go.unsqueeze(-1) + if len(rewards.shape) == 2: + rewards = rewards.unsqueeze(-1) + # Guarantee return and reward has shape [B, T, 1] if self._basic_discrete_env: actions = actions.to(torch.long) actions = actions.squeeze(-1) - action_target = torch.clone(actions).detach().to(self._device) + action_target = torch.clone(actions).detach().to(self._device) # [B, T, A] + state_target = torch.clone(states).detach().to(self._device) # [B, T, S] + return_to_go_target = torch.clone(returns_to_go).detach().to(self._device) if self._atari_env: - state_preds, action_preds, return_preds, return_preds2, reward_preds = self._learn_model.forward( + state_preds, action_preds, return_preds, imp_return_preds, reward_preds = self._learn_model.forward( timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1 ) else: - state_preds, action_preds, return_preds, return_preds2, reward_preds = self._learn_model.forward( + state_preds, action_preds, return_preds, imp_return_preds, reward_preds = self._learn_model.forward( timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go ) - + def expectile_loss(diff: torch.Tensor, expectile: float=0.8) -> torch.Tensor: + weight = torch.where(diff > 0, expectile, (1 - expectile)) + return weight * (diff**2) + + def cross_entropy(logits, labels): + # labels = F.one_hot(labels.long(), num_classes=int(num_bin)).squeeze(2) + labels = F.one_hot( + labels.long(), num_classes=int(self.num_bin) + ).squeeze() + criterion = torch.nn.CrossEntropyLoss() + return criterion(logits, labels.float()) + + def encode_return(env_name, ret, scale=1.0, num_bin=120, rtg_scale=1000): + env_key = env_name.split("-")[0].lower() + if env_key not in REF_MAX_SCORE: + ret_max = 100 + else: + ret_max = REF_MAX_SCORE[env_key] + if env_key not in REF_MIN_SCORE: + ret_min = -20 + else: + ret_min = REF_MIN_SCORE[env_key] + ret_max /= rtg_scale + ret_min /= rtg_scale + interval = (ret_max - ret_min) / (num_bin-1) + ret = torch.clip(ret, 
ret_min, ret_max) + return ((ret - ret_min) // interval).float() + + + if self._atari_env: action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1)) else: @@ -154,16 +216,45 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: # only consider non padded elements action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0] - + state_preds = state_preds.view(-1, self.state_dim)[traj_mask > 0] + imp_return_preds = imp_return_preds.reshape(-1, 1)[traj_mask > 0] + return_preds = return_preds.reshape(-1, int(self.num_bin))[traj_mask > 0] + + if self._cfg.model.continuous: action_target = action_target.view(-1, self.act_dim)[traj_mask > 0] action_loss = F.mse_loss(action_preds, action_target) + state_target = next_states.view(-1, self.state_dim)[traj_mask > 0] + state_loss = F.mse_loss(state_preds, state_target) + imp_return_target = returns_to_go.reshape(-1, 1)[traj_mask > 0] + imp_loss = expectile_loss((imp_return_target - imp_return_preds), self.expectile).mean() + return_target = ( + encode_return( + self.env_name, + returns_to_go, + num_bin=self.num_bin, + rtg_scale=self.rtg_scale, + ).float().reshape(-1, 1)[traj_mask > 0] + ) + return_cross_entropy_loss = cross_entropy(return_preds, return_target) + else: action_target = action_target.view(-1)[traj_mask > 0] action_loss = F.cross_entropy(action_preds, action_target) - + state_target = next_states.view(-1)[traj_mask > 0] + state_loss = F.cross_entropy(state_preds, state_target) + imp_return_target = returns_to_go.reshape(-1, 1)[traj_mask > 0] + imp_loss = expectile_loss((imp_return_target - imp_return_preds), self.expectile).mean() + + edt_loss = action_loss \ + + state_loss * self.state_loss_weight \ + + imp_loss * self.exp_loss_weight \ + + if self._cfg.model.continuous: + edt_loss += return_cross_entropy_loss * self.cross_entropy_weight + self._optimizer.zero_grad() - action_loss.backward() + edt_loss.backward() if self._cfg.multi_gpu: self.sync_gradients(self._learn_model) torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p) @@ -173,7 +264,9 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: return { 'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'], 'action_loss': action_loss.detach().cpu().item(), - 'total_loss': action_loss.detach().cpu().item(), + 'state_loss': state_loss.detach().cpu().item(), + 'implict_loss': imp_loss.detach().cpu().item(), + 'total_loss': edt_loss.detach().cpu().item(), } def _init_eval(self) -> None: @@ -195,6 +288,8 @@ def _init_eval(self) -> None: self._eval_model = self._model # init data self._device = torch.device(self._device) + + self.real_rtg = self._cfg.real_rtg self.rtg_scale = self._cfg.rtg_scale # normalize returns to go self.rtg_target = self._cfg.rtg_target # max target reward_to_go self.state_dim = self._cfg.model.state_dim @@ -202,42 +297,147 @@ def _init_eval(self) -> None: self.eval_batch_size = self._cfg.evaluator_env_num self.max_eval_ep_len = self._cfg.max_eval_ep_len self.context_len = self._cfg.model.context_len # K in decision transformer - + self.expectile = self._cfg.weights.expectile + + self.rs_steps = self._cfg.eval.rs_steps + self.rs_ratio = self._cfg.weights.rs_ratio + self.heuristic = self._cfg.eval.heuristic + self.heuristic_delta = self._cfg.eval.heuristic_delta + + + self.t = [0 for _ in range(self.eval_batch_size)] if self._cfg.model.continuous: - self.actions = torch.zeros( - (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), 
dtype=torch.float32, device=self._device - ) + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, self.act_dim), + dtype=torch.float32, device=self._device) else: - self.actions = torch.zeros( - (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device - ) + # (B, eval_len + 2 * context_len, A) for actions + self.actions = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, 1), + dtype=torch.long, device=self._device) + self._atari_env = 'state_mean' not in self._cfg self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg + if self._atari_env: - self.states = torch.zeros( - ( - self.eval_batch_size, - self.max_eval_ep_len, - ) + tuple(self.state_dim), - dtype=torch.float32, - device=self._device - ) self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len,) + tuple(self.state_dim), + dtype=torch.float32, device=self._device) else: + # (B, eval_len + 2 * context_len, S) for states self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] - self.states = torch.zeros( - (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device - ) + self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, self.state_dim), + dtype=torch.float32, device=self._device) self.state_mean = torch.from_numpy(np.array(self._cfg.state_mean)).to(self._device) self.state_std = torch.from_numpy(np.array(self._cfg.state_std)).to(self._device) - self.timesteps = torch.arange( - start=0, end=self.max_eval_ep_len, step=1 - ).repeat(self.eval_batch_size, 1).to(self._device) - self.rewards_to_go = torch.zeros( - (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device - ) - + + + self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len + 2 * self.context_len, step=1) + self.timesteps = self.timesteps.repeat(self.eval_batch_size, 1).to(self._device) + + # (B, eval_len + 2 * context_len, 1) for rtg & rewards + self.rewards_to_go = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, 1), + dtype=torch.float32, device=self._device) + self.rewards = torch.zeros((self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, 1), + dtype=torch.float32, device=self._device) + + def decode_return(env_name: str, ret, scale: float=1.0, num_bin: int=120, rtg_scale: int=1000): + env_key = env_name.split("-")[0].lower() + if env_key not in REF_MAX_SCORE: + ret_max = 100 + else: + ret_max = REF_MAX_SCORE[env_key] + if env_key not in REF_MIN_SCORE: + ret_min = -20 + else: + ret_min = REF_MIN_SCORE[env_key] + ret_max /= rtg_scale + ret_min /= rtg_scale + interval = (ret_max - ret_min) / num_bin + return ret * interval + ret_min + + def _return_heuristic(self, + model: torch.nn.Module, + timesteps: torch.Tensor, + states: torch.Tensor, + actions: torch.Tensor, + rewards_to_go: torch.Tensor, + rewards: torch.Tensor, + context_len: int, + t: int, + # top_percentile: float, + # num_bin: int, + # rtg_scale: int, + # expert_weight: float, + # mgdt_sampling: bool = False, + rs_steps: int = 2, + rs_ratio: int = 1, + real_rtg: bool = False, + use_heuristic: bool = False, + heuristic_delta: int = 1, + previous_index: Optional[int] = None, + ) -> Tuple[torch.Tensor, int]: + highest_ret = -9999 + estimated_rtg = None + best_i = 0 + 
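+        # Search over candidate context offsets: each candidate window is passed through the model twice,
+        # first with the stored returns-to-go and then with the model's own implicit return prediction fed
+        # back in; the action taken at the offset with the highest predicted return is kept.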
best_act = None + if t < context_len: + for i in range(0, math.ceil((t + 1) / rs_ratio), rs_steps): + _, act_preds, ret_preds, imp_ret_preds, _ = model.forward( + timesteps[:, i : context_len + i], + states[:, i : context_len + i], + actions[:, i : context_len + i], + rewards_to_go[:, i : context_len + i], + rewards[:, i : context_len + i], + ) + _, act_preds, ret_preds, imp_ret_preds_pure, _ = model.forward( + timesteps[:, i : context_len + i], + states[:, i : context_len + i], + actions[:, i : context_len + i], + imp_ret_preds, + rewards[:, i : context_len + i], + ) + if not real_rtg: + imp_ret_preds = imp_ret_preds_pure + ret_i = imp_ret_preds[:, t - i].detach().item() + if ret_i > highest_ret: + highest_ret = ret_i + best_i = i + estimated_rtg = imp_ret_preds.detach() + best_act = act_preds[0, t - i].detach() + else: + if use_heuristic: + prev_best_index = context_len - previous_index + loop = (prev_best_index-heuristic_delta, prev_best_index+1+heuristic_delta) + else: + loop = (0, math.ceil(context_len/rs_ratio), rs_steps) + for i in range(*loop): + if use_heuristic and (i < 0 or i >= context_len): + continue + _, act_preds, ret_preds, imp_ret_preds, _ = model.forward( + timesteps[:, t - context_len + 1 + i : t + 1 + i], + states[:, t - context_len + 1 + i : t + 1 + i], + actions[:, t - context_len + 1 + i : t + 1 + i], + rewards_to_go[:, t - context_len + 1 + i : t + 1 + i], + rewards[:, t - context_len + 1 + i : t + 1 + i], + ) + _, act_preds, ret_preds, imp_ret_preds_pure, _ = model.forward( + timesteps[:, t - context_len + 1 + i : t + 1 + i], + states[:, t - context_len + 1 + i : t + 1 + i], + actions[:, t - context_len + 1 + i : t + 1 + i], + imp_ret_preds, + rewards[:, t - context_len + 1 + i : t + 1 + i], + ) + if not real_rtg: + imp_ret_preds = imp_ret_preds_pure + + ret_i = imp_ret_preds[:, -1 - i].detach().item() + if ret_i > highest_ret: + highest_ret = ret_i + best_i = i + # estimated_rtg = imp_ret_preds.detach() + best_act = act_preds[0, -1 - i].detach() + return best_act, context_len - best_i + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: """ Overview: @@ -256,94 +456,94 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: Decision Transformer will do different operations for different types of envs in evaluation. 
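+            For EDT, each evaluation step additionally calls ``_return_heuristic`` to search over candidate \
+            history lengths and picks the action whose predicted (implicit) return-to-go is the highest.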
""" # save and forward + + data_id = list(data.keys()) self._eval_model.eval() - with torch.no_grad(): - if self._atari_env: - states = torch.zeros( - ( - self.eval_batch_size, - self.context_len, - ) + tuple(self.state_dim), - dtype=torch.float32, - device=self._device - ) - timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, device=self._device) - else: - states = torch.zeros( - (self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device - ) - timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device) - if not self._cfg.model.continuous: - actions = torch.zeros( - (self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device - ) - else: - actions = torch.zeros( - (self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self._device - ) - rewards_to_go = torch.zeros( - (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device - ) + with torch.no_grad(): + + print(self.t) + best_acts = [] for i in data_id: - if self._atari_env: - self.states[i, self.t[i]] = data[i]['obs'].to(self._device) - else: - self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std - self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale).to(self._device) - self.rewards_to_go[i, self.t[i]] = self.running_rtg[i] - - if self.t[i] <= self.context_len: - if self._atari_env: - timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones( - (1, 1), dtype=torch.int64 - ).to(self._device) - else: - timesteps[i] = self.timesteps[i, :self.context_len] - states[i] = self.states[i, :self.context_len] - actions[i] = self.actions[i, :self.context_len] - rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] - else: + curr_states = self.states[i].unsqueeze(0) + curr_runninng_rtg = self.running_rtg[i] + curr_rewards_to_go = self.rewards_to_go[i].unsqueeze(0) + curr_rewards = self.rewards[i].unsqueeze(0) + curr_actions = self.actions[i].unsqueeze(0) + previous_index = None + for t in range(self.max_eval_ep_len): if self._atari_env: - timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones( - (1, 1), dtype=torch.int64 - ).to(self._device) + curr_states[0, t] = data[i]['obs'].to(self._device) else: - timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] - states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] - actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1] - rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1] - if self._basic_discrete_env: - actions = actions.squeeze(-1) - _, act_preds, _, _, _= self._eval_model.forward(timesteps, states, actions, rewards_to_go) - del timesteps, states, actions, rewards_to_go - - logits = act_preds[:, -1, :] - if not self._cfg.model.continuous: - if self._atari_env: - probs = F.softmax(logits, dim=-1) - act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device) - for i in data_id: - act[i] = torch.multinomial(probs[i], num_samples=1) - else: - act = torch.argmax(logits, axis=1).unsqueeze(1) - else: - act = logits + curr_states[0, t] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std + # print(f"curr_states[0, t] 的 shape 是 {curr_states[0, t].shape}, 而 state的shape是{curr_states.shape}") + curr_runninng_rtg = curr_runninng_rtg - (data[i]['reward'] / 
self.rtg_scale).to(self._device)
+                    curr_rewards_to_go[0, t] = curr_runninng_rtg
+                    curr_rewards[0, t] = data[i]['reward']
+                    act, best_index = self._return_heuristic(
+                        model=self._eval_model,
+                        timesteps=self.timesteps[i].unsqueeze(0),
+                        states=curr_states,
+                        actions=curr_actions,
+                        rewards_to_go=curr_rewards_to_go,
+                        rewards=curr_rewards,
+                        context_len=self.context_len,
+                        t=t,
+                        rs_steps=self.rs_steps,
+                        rs_ratio=self.rs_ratio,
+                        real_rtg=self.real_rtg,
+                        use_heuristic=self.heuristic,
+                        heuristic_delta=self.heuristic_delta,
+                        previous_index=previous_index
+                    )
+                    previous_index = best_index
+                best_acts.append(act)
+            acts = torch.stack(best_acts, dim=0)
             for i in data_id:
-                self.actions[i, self.t[i]] = act[i]  # TODO: self.actions[i] should be a queue when exceed max_t
-                self.t[i] += 1
-
-            if self._cuda:
-                act = to_device(act, 'cpu')
-            output = {'action': act}
-            output = default_decollate(output)
+                self.actions[i, self.t[i]] = acts[i]  # TODO: self.actions[i] should be a queue when exceed max_t
+                self.t[i] += 1
+
+            if self._cuda:
+                acts = to_device(acts, 'cpu')
+            output = {'action': acts}
+            output = default_decollate(output)
         return {i: d for i, d in zip(data_id, output)}
 
     def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
         """
         Overview:
             Reset some stateful variables for eval mode when necessary, such as the historical info of transformer \
             for decision transformer. If ``data_id`` is None, it means to reset all the stateful \
             varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
             different environments/episodes in evaluation in ``data_id`` will have different history.
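+            For EDT, every stateful buffer below is allocated with ``max_eval_ep_len + 2 * context_len`` \
+            entries so that the return-maximization search can shift its context window past the current step.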
@@ -355,15 +555,15 @@ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: if data_id is None: self.t = [0 for _ in range(self.eval_batch_size)] self.timesteps = torch.arange( - start=0, end=self.max_eval_ep_len, step=1 + start=0, end=self.max_eval_ep_len + 2 * self.context_len, step=1 ).repeat(self.eval_batch_size, 1).to(self._device) if not self._cfg.model.continuous: self.actions = torch.zeros( - (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device + (self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, 1), dtype=torch.long, device=self._device ) else: self.actions = torch.zeros( - (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), + (self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, self.act_dim), dtype=torch.float32, device=self._device ) @@ -371,7 +571,7 @@ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: self.states = torch.zeros( ( self.eval_batch_size, - self.max_eval_ep_len, + self.max_eval_ep_len + 2 * self.context_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device @@ -379,36 +579,37 @@ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] else: self.states = torch.zeros( - (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), + (self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, self.state_dim), dtype=torch.float32, device=self._device ) self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] self.rewards_to_go = torch.zeros( - (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device + (self.eval_batch_size, self.max_eval_ep_len + 2 * self.context_len, 1), dtype=torch.float32, device=self._device ) else: for i in data_id: self.t[i] = 0 if not self._cfg.model.continuous: - self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self._device) + self.actions[i] = torch.zeros((self.max_eval_ep_len + 2 * self.context_len, 1), dtype=torch.long, device=self._device) else: self.actions[i] = torch.zeros( - (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device + (self.max_eval_ep_len + 2 * self.context_len, self.act_dim), dtype=torch.float32, device=self._device ) if self._atari_env: self.states[i] = torch.zeros( - (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device + (self.max_eval_ep_len + 2 * self.context_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device ) self.running_rtg[i] = self.rtg_target else: self.states[i] = torch.zeros( - (self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device + (self.max_eval_ep_len + 2 * self.context_len, self.state_dim), dtype=torch.float32, device=self._device ) self.running_rtg[i] = self.rtg_target / self.rtg_scale - self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device) - self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) + self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len + 2 * self.context_len, step=1).to(self._device) + self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len + 2 * self.context_len, 1), dtype=torch.float32, device=self._device) + self.rewards[i] = torch.zeros((self.max_eval_ep_len + 2 * self.context_len, 1), dtype=torch.float32, device=self._device) def _monitor_vars_learn(self) -> 
List[str]: """ diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index f29dd3335a..cb0bc04046 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -843,7 +843,291 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso traj_mask = torch.ones(self.context_len, dtype=torch.long) return timesteps, states, actions, rtgs, traj_mask +@DATASET_REGISTRY.register('edt_d4rl_trajectory') +class EDTTrajectoryDataset(D4RLTrajectoryDataset): + """ + Overview: + D4RL trajectory dataset for EDT, which is used for offline RL algorithms. + Interfaces: + ``__init__``, ``__len__``, ``__getitem__`` + """ + def __init__(self, cfg: dict) -> None: + """ + Overview: + Initialization method. + Arguments: + - cfg (:obj:`dict`): Config dict. + """ + dataset_path = cfg.dataset.data_dir_prefix + rtg_scale = cfg.dataset.rtg_scale + self.context_len = cfg.dataset.context_len + self.env_type = cfg.dataset.env_type + + if 'hdf5' in dataset_path: # for mujoco env + try: + import h5py + import collections + except ImportError: + import sys + logging.warning("not found h5py package, please install it trough `pip install h5py ") + sys.exit(1) + dataset = h5py.File(dataset_path, 'r') + + N = dataset['rewards'].shape[0] + data_ = collections.defaultdict(list) + + use_timeouts = False + if 'timeouts' in dataset: + use_timeouts = True + + episode_step = 0 + paths = [] + for i in range(N): + done_bool = bool(dataset['terminals'][i]) + if use_timeouts: + final_timestep = dataset['timeouts'][i] + else: + final_timestep = (episode_step == 1000 - 1) + for k in ['observations', 'actions', 'rewards', 'terminals']: + data_[k].append(dataset[k][i]) + if done_bool or final_timestep: + episode_step = 0 + episode_data = {} + for k in data_: + episode_data[k] = np.array(data_[k]) + paths.append(episode_data) + data_ = collections.defaultdict(list) + episode_step += 1 + + self.trajectories = paths + + + # calculate state mean and variance and returns_to_go for all traj + states = [] + for traj in self.trajectories: + traj_len = traj['observations'].shape[0] + states.append(traj['observations']) + # calculate returns to go and rescale them + traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale + # used for input normalization + states = np.concatenate(states, axis=0) + self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + # normalize states + for traj in self.trajectories: + traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std + traj['next_observations'] = (traj['next_observations'] - self.state_mean) / self.state_std + + elif 'pkl' in dataset_path: + if 'dqn' in dataset_path: + # load dataset + with open(dataset_path, 'rb') as f: + self.trajectories = pickle.load(f) + + if isinstance(self.trajectories[0], list): + # for our collected dataset, e.g. 
cartpole/lunarlander case + trajectories_tmp = [] + + original_keys = ['obs', 'next_obs', 'action', 'reward'] + keys = ['observations', 'next_observations', 'actions', 'rewards'] + trajectories_tmp = [ + { + key: np.stack( + [ + self.trajectories[eps_index][transition_index][o_key] + for transition_index in range(len(self.trajectories[eps_index])) + ], + axis=0 + ) + for key, o_key in zip(keys, original_keys) + } for eps_index in range(len(self.trajectories)) + ] + self.trajectories = trajectories_tmp + + states = [] + for traj in self.trajectories: + # traj_len = traj['observations'].shape[0] + states.append(traj['observations']) + # calculate returns to go and rescale them + traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale + + # used for input normalization + states = np.concatenate(states, axis=0) + self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + # normalize states + for traj in self.trajectories: + traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std + traj['next_observations'] = (traj['next_observations'] - self.state_mean) / self.state_std + else: + # load dataset + with open(dataset_path, 'rb') as f: + self.trajectories = pickle.load(f) + + states = [] + for traj in self.trajectories: + states.append(traj['observations']) + # calculate returns to go and rescale them + traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale + + # used for input normalization + states = np.concatenate(states, axis=0) + self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + # normalize states + for traj in self.trajectories: + traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std + traj['next_observations'] = (traj['next_observations'] - self.state_mean) / self.state_std + else: + # -- load data from memory (make more efficient) + obss = [] + actions = [] + returns = [0] + done_idxs = [] + stepwise_returns = [] + + transitions_per_buffer = np.zeros(50, dtype=int) + num_trajectories = 0 + while len(obss) < cfg.dataset.num_steps: + buffer_num = np.random.choice(np.arange(50 - cfg.dataset.num_buffers, 50), 1)[0] + i = transitions_per_buffer[buffer_num] + frb = FixedReplayBuffer( + data_dir=cfg.dataset.data_dir_prefix + '/1/replay_logs', + replay_suffix=buffer_num, + observation_shape=(84, 84), + stack_size=4, + update_horizon=1, + gamma=0.99, + observation_dtype=np.uint8, + batch_size=32, + replay_capacity=100000 + ) + if frb._loaded_buffers: + done = False + curr_num_transitions = len(obss) + trajectories_to_load = cfg.dataset.trajectories_per_buffer + while not done: + states, ac, ret, next_states, next_action, next_reward, terminal, indices = \ + frb.sample_transition_batch(batch_size=1, indices=[i]) + states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) + obss.append(states) + actions.append(ac[0]) + stepwise_returns.append(ret[0]) + if terminal[0]: + done_idxs.append(len(obss)) + returns.append(0) + if trajectories_to_load == 0: + done = True + else: + trajectories_to_load -= 1 + returns[-1] += ret[0] + i += 1 + if i >= 100000: + obss = obss[:curr_num_transitions] + actions = actions[:curr_num_transitions] + stepwise_returns = stepwise_returns[:curr_num_transitions] + returns[-1] = 0 + i = transitions_per_buffer[buffer_num] + done = True + num_trajectories += (cfg.dataset.trajectories_per_buffer - trajectories_to_load) + transitions_per_buffer[buffer_num] = i + + actions = np.array(actions) + returns 
= np.array(returns) + stepwise_returns = np.array(stepwise_returns) + done_idxs = np.array(done_idxs) + + # -- create reward-to-go dataset + start_index = 0 + rtg = np.zeros_like(stepwise_returns) + for i in done_idxs: + i = int(i) + curr_traj_returns = stepwise_returns[start_index:i] + for j in range(i - 1, start_index - 1, -1): # start from i-1 + rtg_j = curr_traj_returns[j - start_index:i - start_index] + rtg[j] = sum(rtg_j) + start_index = i + + # -- create timestep dataset + start_index = 0 + timesteps = np.zeros(len(actions) + 1, dtype=int) + for i in done_idxs: + i = int(i) + timesteps[start_index:i + 1] = np.arange(i + 1 - start_index) + start_index = i + 1 + + self.obss = obss + self.actions = actions + self.done_idxs = done_idxs + self.rtgs = rtg + self.timesteps = timesteps + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Overview: + Get the item of the dataset. + Arguments: + - idx (:obj:`int`): The index of the dataset. + """ + if self.env_type != 'atari': + traj = self.trajectories[idx] + traj_len = traj['observations'].shape[0] + + if traj_len > self.context_len: + si = np.random.randint(0, traj_len - self.context_len) + states = torch.from_numpy(traj['observations'][si:si + self.context_len]) + next_states = torch.from_numpy(traj["next_observations"][si:si + self.context_len]) + actions = torch.from_numpy(traj['actions'][si:si + self.context_len]) + returns_to_go = torch.from_numpy(traj['returns_to_go'][si:si + self.context_len]) + rewards = torch.from_numpy(traj["rewards"][si : si + self.context_len]) + timesteps = torch.arange(start=si, end=si + self.context_len, step=1) + + # all ones since no padding + traj_mask = torch.ones(self.context_len, dtype=torch.long) + else: + padding_len = self.context_len - traj_len + + # padding with zeros + states = torch.from_numpy(traj['observations']) + states = torch.cat( + [states, torch.zeros(([padding_len] + list(states.shape[1:])), dtype=states.dtype)], dim=0 + ) + + next_states = torch.from_numpy(traj['next_observations']) + next_states = torch.cat( + [next_states, torch.zeros(([padding_len] + list(next_states.shape[1:])), dtype=states.dtype)], dim=0 + ) + + actions = torch.from_numpy(traj['actions']) + actions = torch.cat( + [actions, torch.zeros(([padding_len] + list(actions.shape[1:])), dtype=actions.dtype)], dim=0 + ) + + returns_to_go = torch.from_numpy(traj['returns_to_go']) + returns_to_go = torch.cat( + [ + returns_to_go, + torch.zeros(([padding_len] + list(returns_to_go.shape[1:])), dtype=returns_to_go.dtype) + ], + dim=0 + ) + + rewards = torch.from_numpy(traj["rewards"]) + rewards = torch.cat( + [ + rewards, + torch.zeros(([padding_len] + list(rewards.shape[1:])), dtype=rewards.dtype,), + ], + dim=0 + ) + timesteps = torch.arange(start=0, end=self.context_len, step=1) + + traj_mask = torch.cat( + [torch.ones(traj_len, dtype=torch.long), + torch.zeros(padding_len, dtype=torch.long)], dim=0 + ) + return timesteps, states, next_states, actions, returns_to_go, rewards, traj_mask @DATASET_REGISTRY.register('d4rl_diffuser') class D4RLDiffuserDataset(Dataset): """ diff --git a/dizoo/d4rl/config/halfcheetah_medium_edt_config.py b/dizoo/d4rl/config/halfcheetah_medium_edt_config.py index dd48e15ecd..526d6922ce 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_edt_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_edt_config.py @@ -18,7 +18,9 @@ data_dir_prefix='/d4rl/halfcheetah-medium-v2.pkl', ), policy=dict( + 
env_id='HalfCheetah-v3', cuda=True, + real_rtg=False, stop_value=6000, state_mean=None, state_std=None, @@ -49,11 +51,25 @@ ), learn=dict(batch_size=128), learning_rate=1e-4, + weights=dict( + top_percentile=0.15, + expectile=0.99, + expert_weight=10, + exp_loss_weight=0.5, + state_loss_weight=1.0, + cross_entropy_weight=0.001, + rs_ratio=1, # between 1 and 2 + + ), collect=dict( data_type='d4rl_trajectory', unroll_len=1, ), - eval=dict(evaluator=dict(eval_freq=1000, ), ), + eval=dict( + evaluator=dict(eval_freq=1000, ), + rs_steps=2, + heuristic=False, + heuristic_delta=2), ), ) diff --git a/dizoo/d4rl/config/hopper_medium_edt_config.py b/dizoo/d4rl/config/hopper_medium_edt_config.py index 20d5fbed1f..c00ebe44c1 100644 --- a/dizoo/d4rl/config/hopper_medium_edt_config.py +++ b/dizoo/d4rl/config/hopper_medium_edt_config.py @@ -1,31 +1,33 @@ from easydict import EasyDict from copy import deepcopy -hopper_edt_config = dict( +hopper_dt_config = dict( exp_name='edt_log/d4rl/hopper/hopper_medium_edt_seed0', env=dict( env_id='Hopper-v3', collector_env_num=1, - evaluator_env_num=8, + evaluator_env_num=2, use_act_scale=True, - n_evaluator_episode=8, + n_evaluator_episode=2, stop_value=3600, ), dataset=dict( env_type='mujoco', rtg_scale=1000, context_len=20, - data_dir_prefix='/d4rl/hopper-medium-v2.pkl', + data_dir_prefix='/d4rl/hopper-medium-v2.pkl', #! This points out the directory of dataset ), policy=dict( + env_id='Hopper-v3', + real_rtg=False, cuda=True, stop_value=3600, state_mean=None, state_std=None, - evaluator_env_num=8, + evaluator_env_num=2, #! the evaluator env num in policy should be equal to env env_name='Hopper-v3', rtg_target=3600, # max target return to go - max_eval_ep_len=1000, # max lenght of one episode + max_eval_ep_len=20, # max lenght of one episode wt_decay=1e-4, warmup_steps=10000, context_len=20, @@ -34,10 +36,10 @@ model=dict( state_dim=11, act_dim=3, - n_blocks=4, + n_blocks=3, h_dim=512, context_len=20, - n_heads=4, + n_heads=1, drop_p=0.1, max_timestep=4096, num_bin=60, @@ -49,17 +51,31 @@ ), learn=dict(batch_size=128,), learning_rate=1e-4, + weights=dict( + top_percentile=0.15, + expectile=0.99, + expert_weight=10, + exp_loss_weight=0.5, + state_loss_weight=1.0, + cross_entropy_weight=0.001, + rs_ratio=1, # between 1 and 2 + + ), collect=dict( - data_type='d4rl_trajectory', + data_type='edt_d4rl_trajectory', unroll_len=1, ), - eval=dict(evaluator=dict(eval_freq=1000, ), ), + eval=dict( + evaluator=dict(eval_freq=1000, ), + rs_steps=2, + heuristic=False, + heuristic_delta=2), ), ) -hopper_edt_config = EasyDict(hopper_edt_config) -main_config = hopper_edt_config -hopper_edt_create_config = dict( +hopper_dt_config = EasyDict(hopper_dt_config) +main_config = hopper_dt_config +hopper_dt_create_config = dict( env=dict( type='mujoco', import_names=['dizoo.mujoco.envs.mujoco_env'], @@ -67,8 +83,8 @@ env_manager=dict(type='subprocess'), policy=dict(type='edt'), ) -hopper_edt_create_config = EasyDict(hopper_edt_create_config) -create_config = hopper_edt_create_config +hopper_dt_create_config = EasyDict(hopper_dt_create_config) +create_config = hopper_dt_create_config if __name__ == "__main__": from ding.entry import serial_pipeline_edt diff --git a/dizoo/d4rl/config/walker2d_medium_edt_config.py b/dizoo/d4rl/config/walker2d_medium_edt_config.py index a695da8922..3d2f31123b 100644 --- a/dizoo/d4rl/config/walker2d_medium_edt_config.py +++ b/dizoo/d4rl/config/walker2d_medium_edt_config.py @@ -18,6 +18,8 @@ data_dir_prefix='/d4rl/walker2d-medium-v2.pkl', ), policy=dict( + 
env_id='Walker2d-v3',
+        real_rtg=False,
         cuda=True,
         stop_value=5000,
         state_mean=None,
@@ -49,11 +51,25 @@
         ),
         learn=dict(batch_size=128),
         learning_rate=1e-4,
+        weights=dict(
+            top_percentile=0.15,
+            expectile=0.99,
+            expert_weight=10,
+            exp_loss_weight=0.5,
+            state_loss_weight=1.0,
+            cross_entropy_weight=0.001,
+            rs_ratio=1,  # between 1 and 2
+        ),
         collect=dict(
-            data_type='d4rl_trajectory',
+            data_type='edt_d4rl_trajectory',
             unroll_len=1,
         ),
-        eval=dict(evaluator=dict(eval_freq=1000, ), ),
+        eval=dict(
+            evaluator=dict(eval_freq=1000, ),
+            rs_steps=2,
+            heuristic=False,
+            heuristic_delta=2,
+        ),
     ),
 )
@@ -73,4 +89,4 @@
 if __name__ == "__main__":
     from ding.entry import serial_pipeline_edt
     config = deepcopy([main_config, create_config])
-    serial_pipeline_edt(config, seed=0, max_train_iter=1000)
+    serial_pipeline_edt(config, seed=0, max_train_iter=1000)
\ No newline at end of file