Allow temp storage to grow to fit the largest requested data (#174)
Some background for this change: for some AMP code I need to sample the same memory with two different dataloaders. One loader always requests less data than the other, and if the smaller request is sampled first, the temp storage ends up sized for it and the larger request no longer fits. This PR changes the behavior so that the temp storage grows whenever the requested data does not fit. Assuming the requested data sizes do not change over time, this only reallocates at the start, until the largest buffer size has been sampled once. What do you think of this?
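For illustration, here is a minimal sketch of the grow-to-fit idea described above, assuming a standalone scratch-buffer helper; the class and method names are hypothetical and not part of emote's actual API:

```python
import numpy as np


class GrowableTempStorage:
    """Hypothetical sketch of the grow-to-fit behavior (not emote's real API).

    Rather than sizing the scratch buffer for the first request (and then
    failing when a larger one arrives), grow it whenever a request does not
    fit. If request sizes are stable over time, reallocation happens only
    until the largest size has been seen once.
    """

    def __init__(self, dtype=np.float32):
        self._buffer = np.empty((0,), dtype=dtype)

    def get(self, size: int) -> np.ndarray:
        if size > self._buffer.shape[0]:
            # Does not fit: reallocate to the new maximum size.
            self._buffer = np.empty((size,), dtype=self._buffer.dtype)
        # Fits: hand out a view into the existing allocation.
        return self._buffer[:size]


storage = GrowableTempStorage()
small = storage.get(256)   # first request: allocates 256 elements
large = storage.get(512)   # larger request: buffer grows to 512
small2 = storage.get(256)  # smaller again: reuses the 512-element buffer
```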
1 parent 3e02a4b · commit 5b50d8d
Showing 2 changed files with 78 additions and 9 deletions.
New test file (66 lines added):

```python
import numpy as np
import pytest
import torch

from emote.memory.adaptors import TerminalAdaptor
from emote.memory.fifo_strategy import FifoEjectionStrategy
from emote.memory.storage import SyntheticDones
from emote.memory.table import ArrayTable, Column, TagColumn, VirtualColumn
from emote.memory.uniform_strategy import UniformSampleStrategy


@pytest.fixture
def table():
    spec = [
        Column(name="obs", dtype=np.float32, shape=(3,)),
        Column(name="reward", dtype=np.float32, shape=()),
        VirtualColumn("dones", dtype=bool, shape=(1,), target_name="reward", mapper=SyntheticDones),
        VirtualColumn(
            "masks",
            dtype=np.float32,
            shape=(1,),
            target_name="reward",
            mapper=SyntheticDones.as_mask,
        ),
        TagColumn(name="terminal", shape=(), dtype=np.float32),
    ]

    table = ArrayTable(
        columns=spec,
        maxlen=10_000,
        sampler=UniformSampleStrategy(),
        ejector=FifoEjectionStrategy(),
        length_key="reward",
        adaptors=[TerminalAdaptor("terminal", "masks")],
        device="cpu",
    )

    return table


def test_sampled_data_is_always_copied(table: ArrayTable):
    # Fill the table with enough short sequences to sample from.
    for ii in range(0, 600):
        table.add_sequence(
            ii,
            dict(
                obs=[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5]],
                reward=[1, 2, 3, 4],
                terminal=[0, 0, 0, 0, 0],
            ),
        )

    # Alternate between a smaller and a larger batch size, so the temp
    # storage is asked to grow after it has already served a smaller request.
    sample_count = 100
    counts = [256, 512]
    seq_len = 3
    for _ in range(sample_count):
        for count in counts:
            sample1 = table.sample(count, seq_len)
            sample2 = table.sample(count, seq_len)

            for key in sample1.keys():
                col_samp_1: torch.Tensor = sample1[key]
                col_samp_2: torch.Tensor = sample2[key]

                assert (
                    col_samp_1.data_ptr() != col_samp_2.data_ptr()
                ), "2 table samples share memory! This is not allowed! Samples must always copy their data."
```