Skip to content

Commit

Permalink
frame stack bug fix for memory efficient buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
Leo428 committed Dec 21, 2023
1 parent 6cf7d58 commit 82e439e
Showing 1 changed file with 10 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ def __init__(
for pixel_key in self.pixel_keys:
pixel_obs_space = observation_space.spaces[pixel_key]
if self._num_stack is None:
self._num_stack = pixel_obs_space.shape[-1]
self._num_stack = pixel_obs_space.shape[0]
else:
assert self._num_stack == pixel_obs_space.shape[-1]
self._unstacked_dim_size = pixel_obs_space.shape[-2]
low = pixel_obs_space.low[..., 0]
high = pixel_obs_space.high[..., 0]
assert self._num_stack == pixel_obs_space.shape[0]
self._unstacked_dim_size = pixel_obs_space.shape[-1]
low = pixel_obs_space.low[0]
high = pixel_obs_space.high[0]
unstacked_pixel_obs_space = Box(
low=low, high=high, dtype=pixel_obs_space.dtype
)
Expand Down Expand Up @@ -71,13 +71,13 @@ def insert(self, data_dict: DatasetDict):
if self._first:
for i in range(self._num_stack):
for pixel_key in self.pixel_keys:
data_dict["observations"][pixel_key] = obs_pixels[pixel_key][..., i]
data_dict["observations"][pixel_key] = obs_pixels[pixel_key][i]

self._is_correct_index[self._insert_index] = False
super().insert(data_dict)

for pixel_key in self.pixel_keys:
data_dict["observations"][pixel_key] = next_obs_pixels[pixel_key][..., -1]
data_dict["observations"][pixel_key] = next_obs_pixels[pixel_key][-1]

self._first = data_dict["dones"]

Expand Down Expand Up @@ -151,12 +151,13 @@ def sample(
obs_pixels, self._num_stack + 1, axis=0
)
obs_pixels = obs_pixels[indx - self._num_stack]
obs_pixels = obs_pixels.transpose((0, 4, 1, 2, 3))

if pack_obs_and_next_obs:
batch["observations"][pixel_key] = obs_pixels
else:
batch["observations"][pixel_key] = obs_pixels[..., :-1]
batch["observations"][pixel_key] = obs_pixels[:, :-1, ...]
if "next_observations" in keys:
batch["next_observations"][pixel_key] = obs_pixels[..., 1:]
batch["next_observations"][pixel_key] = obs_pixels[:, 1:, ...]

return frozen_dict.freeze(batch)

0 comments on commit 82e439e

Please sign in to comment.