Skip to content

Commit

Permalink
Allow temp storage to grow to fit the largest requested data (#174)
Browse files Browse the repository at this point in the history
Some background for this change: For some AMP code I need to sample the
same memory with two different dataloaders. One loader will always
request less data than the other, and if that is sampled before the
larger one, the larger data no longer fits in the temp storage. This PR
changes the behavior so that the temp storage grows when the requested
data does not fit. Assuming the requested data sizes do not change over
time, this will only reallocate at the start, until the largest buffer
size has been sampled.

What do you think of this?
  • Loading branch information
klashenriksson authored Sep 19, 2023
1 parent 3e02a4b commit 5b50d8d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 9 deletions.
21 changes: 12 additions & 9 deletions emote/memory/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ def get_empty_storage(self, count, length):
each time the memory is sampled. Will *not* work if the memory is
sampled from multiple threads.
"""
if self._temp_storage is None:
d = np.empty((count * length, *self._shape), self._dtype)
total_size = count * length
if self._temp_storage is None or self._temp_storage.shape[0] < total_size:
d = np.empty((total_size, *self._shape), self._dtype)
self._temp_storage = d

return self._temp_storage
return self._temp_storage[:total_size]

def sequence_length_transform(self, length):
return length
Expand Down Expand Up @@ -68,11 +69,12 @@ def get_empty_storage(self, count, length):
each time the memory is sampled. Will *not* work if the memory is
sampled from multiple threads.
"""
if self._temp_storage is None:
d = np.empty((count * length, *self._shape), self._dtype)
total_size = count * length
if self._temp_storage is None or self._temp_storage.shape[0] < total_size:
d = np.empty((total_size, *self._shape), self._dtype)
self._temp_storage = d

return self._temp_storage
return self._temp_storage[:total_size]

def sequence_length_transform(self, length):
return 1
Expand Down Expand Up @@ -128,11 +130,12 @@ def sequence_length_transform(self, length):
return length

def get_empty_storage(self, count, length):
    """Return a reusable scratch buffer with room for ``count * length`` items.

    The buffer is shared between calls and only reallocated when a larger
    size is requested (it never shrinks), so callers must consume the data
    before the memory is sampled again. Will *not* work if the memory is
    sampled from multiple threads.

    :param count: number of sequences requested.
    :param length: length of each sequence.
    :return: an uninitialized array view of shape ``(count * length, *self._shape)``.
    """
    total_size = count * length
    # Grow the shared scratch buffer on demand; keep it if already big enough.
    if self._temp_storage is None or self._temp_storage.shape[0] < total_size:
        self._temp_storage = np.empty((total_size, *self._shape), self._dtype)

    # Hand out a view of exactly the requested size, even when the
    # underlying buffer is larger than needed.
    return self._temp_storage[:total_size]

def post_import(self):
pass
Expand Down
66 changes: 66 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np
import pytest
import torch

from emote.memory.adaptors import TerminalAdaptor
from emote.memory.fifo_strategy import FifoEjectionStrategy
from emote.memory.storage import SyntheticDones
from emote.memory.table import ArrayTable, Column, TagColumn, VirtualColumn
from emote.memory.uniform_strategy import UniformSampleStrategy


@pytest.fixture
def table():
    """Build an ArrayTable with obs/reward columns plus synthetic dones/masks.

    The table uses uniform sampling, FIFO ejection, and a terminal adaptor
    so that sampled batches include masks derived from the terminal tags.
    """
    columns = [
        Column(name="obs", dtype=np.float32, shape=(3,)),
        Column(name="reward", dtype=np.float32, shape=()),
        VirtualColumn("dones", dtype=bool, shape=(1,), target_name="reward", mapper=SyntheticDones),
        VirtualColumn(
            "masks",
            dtype=np.float32,
            shape=(1,),
            target_name="reward",
            mapper=SyntheticDones.as_mask,
        ),
        TagColumn(name="terminal", shape=(), dtype=np.float32),
    ]

    return ArrayTable(
        columns=columns,
        maxlen=10_000,
        sampler=UniformSampleStrategy(),
        ejector=FifoEjectionStrategy(),
        length_key="reward",
        adaptors=[TerminalAdaptor("terminal", "masks")],
        device="cpu",
    )


def test_sampled_data_is_always_copied(table: ArrayTable):
    """Verify that two consecutive samples never alias the same backing memory."""
    # Fill the table with 600 identical short episodes.
    for episode_id in range(600):
        table.add_sequence(
            episode_id,
            {
                "obs": [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5]],
                "reward": [1, 2, 3, 4],
                "terminal": [0, 0, 0, 0, 0],
            },
        )

    seq_len = 3
    # Alternate between two batch sizes to exercise temp-storage reuse.
    for _ in range(100):
        for count in (256, 512):
            first = table.sample(count, seq_len)
            second = table.sample(count, seq_len)

            for key, tensor_a in first.items():
                tensor_b: torch.Tensor = second[key]
                assert (
                    tensor_a.data_ptr() != tensor_b.data_ptr()
                ), "2 table samples share memory! This is not allowed! Samples must always copy their data."

0 comments on commit 5b50d8d

Please sign in to comment.