diff --git a/serl_launcher/serl_launcher/data/serl_memory_efficient_replay_buffer.py b/serl_launcher/serl_launcher/data/serl_memory_efficient_replay_buffer.py index 1b6e15bd..0ae6052c 100644 --- a/serl_launcher/serl_launcher/data/serl_memory_efficient_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/serl_memory_efficient_replay_buffer.py @@ -24,12 +24,12 @@ def __init__( for pixel_key in self.pixel_keys: pixel_obs_space = observation_space.spaces[pixel_key] if self._num_stack is None: - self._num_stack = pixel_obs_space.shape[-1] + self._num_stack = pixel_obs_space.shape[0] else: - assert self._num_stack == pixel_obs_space.shape[-1] - self._unstacked_dim_size = pixel_obs_space.shape[-2] - low = pixel_obs_space.low[..., 0] - high = pixel_obs_space.high[..., 0] + assert self._num_stack == pixel_obs_space.shape[0] + self._unstacked_dim_size = pixel_obs_space.shape[-1] + low = pixel_obs_space.low[0] + high = pixel_obs_space.high[0] unstacked_pixel_obs_space = Box( low=low, high=high, dtype=pixel_obs_space.dtype ) @@ -71,13 +71,13 @@ def insert(self, data_dict: DatasetDict): if self._first: for i in range(self._num_stack): for pixel_key in self.pixel_keys: - data_dict["observations"][pixel_key] = obs_pixels[pixel_key][..., i] + data_dict["observations"][pixel_key] = obs_pixels[pixel_key][i] self._is_correct_index[self._insert_index] = False super().insert(data_dict) for pixel_key in self.pixel_keys: - data_dict["observations"][pixel_key] = next_obs_pixels[pixel_key][..., -1] + data_dict["observations"][pixel_key] = next_obs_pixels[pixel_key][-1] self._first = data_dict["dones"] @@ -151,12 +151,13 @@ def sample( obs_pixels, self._num_stack + 1, axis=0 ) obs_pixels = obs_pixels[indx - self._num_stack] + obs_pixels = obs_pixels.transpose((0, 4, 1, 2, 3)) if pack_obs_and_next_obs: batch["observations"][pixel_key] = obs_pixels else: - batch["observations"][pixel_key] = obs_pixels[..., :-1] + batch["observations"][pixel_key] = obs_pixels[:, :-1, ...] if "next_observations" in keys: - batch["next_observations"][pixel_key] = obs_pixels[..., 1:] + batch["next_observations"][pixel_key] = obs_pixels[:, 1:, ...] return frozen_dict.freeze(batch)