[FEAT][zeta.rl]
Kye committed Dec 24, 2023
1 parent 1bb7da1 commit bb82269
Showing 9 changed files with 177 additions and 61 deletions.
13 changes: 12 additions & 1 deletion tests/rl/test_prioritizedreplybuffer.py
@@ -1,7 +1,11 @@
import pytest
import random
import torch
from zeta.rl.PrioritizedReplayBuffer import PrioritizedReplayBuffer, SumTree # Replace 'your_module' with the actual module where classes are defined
from zeta.rl.priortized_replay_buffer import (
PrioritizedReplayBuffer,
SumTree,
) # Replace 'your_module' with the actual module where classes are defined


@pytest.fixture
def replay_buffer():
@@ -11,6 +15,7 @@ def replay_buffer():
device = torch.device("cpu")
return PrioritizedReplayBuffer(state_size, action_size, buffer_size, device)


def test_initialization(replay_buffer):
assert replay_buffer.eps == 1e-2
assert replay_buffer.alpha == 0.1
@@ -21,12 +26,14 @@ def test_initialization(replay_buffer):
assert replay_buffer.size == 100
assert replay_buffer.device == torch.device("cpu")


def test_add(replay_buffer):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
replay_buffer.add(transition)
assert replay_buffer.count == 1
assert replay_buffer.real_size == 1


def test_sample(replay_buffer):
for i in range(10):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
@@ -37,6 +44,7 @@ def test_sample(replay_buffer):
assert len(weights) == 5
assert len(tree_idxs) == 5


def test_update_priorities(replay_buffer):
for i in range(10):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
@@ -46,10 +54,12 @@ def test_update_priorities(replay_buffer):
new_priorities = torch.rand(5)
replay_buffer.update_priorities(tree_idxs, new_priorities)


def test_sample_with_invalid_batch_size(replay_buffer):
with pytest.raises(AssertionError):
replay_buffer.sample(101)


def test_add_with_max_size(replay_buffer):
for i in range(100):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
@@ -58,4 +68,5 @@ def test_add_with_max_size(replay_buffer):
assert replay_buffer.count == 0
assert replay_buffer.real_size == 100


# Additional tests for edge cases, exceptions, and more scenarios can be added as needed.
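For orientation, below is a minimal usage sketch of the buffer these tests exercise. The constructor arguments, method names, and return values are taken directly from the test code above; the fill/sample/update workflow is an illustrative assumption, not part of this commit.

import torch
from zeta.rl.priortized_replay_buffer import PrioritizedReplayBuffer

# Build a small CPU buffer, mirroring the fixture above.
buffer = PrioritizedReplayBuffer(
    state_size=4, action_size=2, buffer_size=100, device=torch.device("cpu")
)

# Add dummy transitions of the form (state, action, reward, next_state, done).
for _ in range(10):
    buffer.add((torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False))

# Sample a prioritized batch, then write updated priorities back via the tree indices.
batch, weights, tree_idxs = buffer.sample(5)
buffer.update_priorities(tree_idxs, torch.rand(5))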
17 changes: 15 additions & 2 deletions tests/rl/test_prioritizedsequencereplybuffer.py
@@ -1,15 +1,22 @@
import pytest
import random
import torch
from zeta.rl.PrioritizedSequenceReplayBuffer import PrioritizedSequenceReplayBuffer, SumTree # Replace 'your_module' with the actual module where classes are defined
from zeta.rl.priortized_rps import (
PrioritizedSequenceReplayBuffer,
SumTree,
) # Replace 'your_module' with the actual module where classes are defined


@pytest.fixture
def replay_buffer():
state_size = 4
action_size = 2
buffer_size = 100
device = torch.device("cpu")
return PrioritizedSequenceReplayBuffer(state_size, action_size, buffer_size, device)
return PrioritizedSequenceReplayBuffer(
state_size, action_size, buffer_size, device
)


def test_initialization(replay_buffer):
assert replay_buffer.eps == 1e-5
@@ -24,12 +31,14 @@ def test_initialization(replay_buffer):
assert replay_buffer.size == 100
assert replay_buffer.device == torch.device("cpu")


def test_add(replay_buffer):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
replay_buffer.add(transition)
assert replay_buffer.count == 1
assert replay_buffer.real_size == 1


def test_sample(replay_buffer):
for i in range(10):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
@@ -40,6 +49,7 @@ def test_sample(replay_buffer):
assert len(weights) == 5
assert len(tree_idxs) == 5


def test_update_priorities(replay_buffer):
for i in range(10):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
@@ -49,10 +59,12 @@ def test_update_priorities(replay_buffer):
new_priorities = torch.rand(5)
replay_buffer.update_priorities(tree_idxs, new_priorities)


def test_sample_with_invalid_batch_size(replay_buffer):
with pytest.raises(AssertionError):
replay_buffer.sample(101)


def test_add_with_max_size(replay_buffer):
for i in range(100):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
@@ -61,4 +73,5 @@ def test_add_with_max_size(replay_buffer):
assert replay_buffer.count == 0
assert replay_buffer.real_size == 100


# Additional tests for edge cases, exceptions, and more scenarios can be added as needed.
24 changes: 19 additions & 5 deletions tests/rl/test_sumtree.py
@@ -1,19 +1,24 @@
import pytest
from zeta.rl.sumtree import SumTree # Replace 'your_module' with the actual module where SumTree is defined
from zeta.rl.sumtree import (
SumTree,
) # Replace 'your_module' with the actual module where SumTree is defined


# Fixture for initializing SumTree instances with a given size
@pytest.fixture
def sum_tree():
size = 10 # You can change the size as needed
return SumTree(size)


# Basic tests
def test_initialization(sum_tree):
assert sum_tree.size == 10
assert sum_tree.count == 0
assert sum_tree.real_size == 0
assert sum_tree.total == 0


def test_update_and_get(sum_tree):
sum_tree.add(5, "data1")
assert sum_tree.total == 5
@@ -22,35 +27,44 @@ def test_update_and_get(sum_tree):
assert priority == 5
assert data == "data1"


def test_add_overflow(sum_tree):
for i in range(15):
sum_tree.add(i, f"data{i}")
assert sum_tree.count == 5
assert sum_tree.real_size == 10


# Parameterized testing for various scenarios
@pytest.mark.parametrize("values, expected_total", [
([1, 2, 3, 4, 5], 15),
([10, 20, 30, 40, 50], 150),
])
@pytest.mark.parametrize(
"values, expected_total",
[
([1, 2, 3, 4, 5], 15),
([10, 20, 30, 40, 50], 150),
],
)
def test_multiple_updates(sum_tree, values, expected_total):
for value in values:
sum_tree.add(value, None)
assert sum_tree.total == expected_total


# Exception testing
def test_get_with_invalid_cumsum(sum_tree):
with pytest.raises(AssertionError):
sum_tree.get(20)


# More tests for specific methods
def test_get_priority(sum_tree):
sum_tree.add(10, "data1")
priority = sum_tree.get_priority(0)
assert priority == 10


def test_repr(sum_tree):
expected_repr = f"SumTree(nodes={sum_tree.nodes}, data={sum_tree.data})"
assert repr(sum_tree) == expected_repr


# More test cases can be added as needed
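The SumTree assertions above imply the following minimal usage pattern. Method names come from the tests and from the replay buffer's sample() further down; the exact return signature of get() is an assumption based on how it is unpacked there.

from zeta.rl.sumtree import SumTree

tree = SumTree(10)            # capacity of 10 leaves, as in the fixture
tree.add(5, "data1")          # store priority 5 alongside the payload "data1"
assert tree.total == 5        # total tracks the running sum of priorities

# get() walks the tree toward a cumulative-sum target; in the replay buffer it is
# unpacked as (tree_idx, priority, data) -- assumed here, not shown in this diff.
tree_idx, priority, data = tree.get(3)

tree.update(0, 7.5)           # overwrite the priority stored at data index 0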
zeta/rl/priortized_replay_buffer.py
@@ -2,21 +2,43 @@
import torch
import random


class PrioritizedReplayBuffer:
def __init__(self, state_size, action_size, buffer_size, device, eps=1e-2, alpha=0.1, beta=0.1):
def __init__(
self,
state_size,
action_size,
buffer_size,
device,
eps=1e-2,
alpha=0.1,
beta=0.1,
):
"""
Initializes a PrioritizedReplayBuffer object.
Args:
state_size (int): The size of the state space.
action_size (int): The size of the action space.
buffer_size (int): The maximum capacity of the buffer.
device (torch.device): The device to store the tensors on.
eps (float, optional): A small constant added to the priorities to ensure non-zero probabilities. Defaults to 1e-2.
alpha (float, optional): The exponent used to compute the priority weights. Defaults to 0.1.
beta (float, optional): The exponent used to compute the importance sampling weights. Defaults to 0.1.
"""
self.tree = SumTree(size=buffer_size)


self.eps = eps
self.alpha = alpha
self.beta = beta
self.max_priority = 1.

self.eps = eps
self.alpha = alpha
self.beta = beta
self.max_priority = 1.0

self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
self.reward = torch.empty(buffer_size, dtype=torch.float)
self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
self.next_state = torch.empty(
buffer_size, state_size, dtype=torch.float
)
self.done = torch.empty(buffer_size, dtype=torch.uint8)

self.count = 0
@@ -25,10 +47,15 @@ def __init__(self, state_size, action_size, buffer_size, device, eps=1e-2, alpha

# device
self.device = device

def add(self, transition):
state, action, reward, next_state, done = transition
"""
Adds a transition to the replay buffer.
Args:
transition (tuple): A tuple containing the state, action, reward, next_state, and done flag.
"""
state, action, reward, next_state, done = transition

self.tree.add(self.max_priority, self.count)

@@ -38,23 +65,32 @@ def add(self, transition):
self.next_state[self.count] = torch.as_tensor(next_state)
self.done[self.count] = torch.as_tensor(done)


self.count = (self.count + 1) % self.size
self.real_size = min(self.size, self.real_size + 1)

def sample(self, batch_size):
assert self.real_size >= batch_size, "buffer contains less samples than batch size"
"""
Samples a batch of transitions from the replay buffer.
Args:
batch_size (int): The size of the batch to sample.
Returns:
tuple: A tuple containing the batch of transitions, importance sampling weights, and tree indices.
"""
assert (
self.real_size >= batch_size
), "buffer contains fewer samples than batch size"

sample_idxs, tree_idxs = [], []
priorities = torch.empty(batch_size, 1, dtype=torch.float)


segment = self.tree.total / batch_size
for i in range(batch_size):
a, b = segment * i, segment * (i + 1)

cumsum = random.uniform(a, b)

tree_idx, priority, sample_idx = self.tree.get(cumsum)

priorities[i] = priority
@@ -71,15 +107,22 @@ def sample(self, batch_size):
self.action[sample_idxs].to(self.device),
self.reward[sample_idxs].to(self.device),
self.next_state[sample_idxs].to(self.device),
self.done[sample_idxs].to(self.device)
self.done[sample_idxs].to(self.device),
)
return batch, weights, tree_idxs

def update_priorities(self, data_idxs, priorities):
"""
Updates the priorities of the transitions in the replay buffer.
Args:
data_idxs (list): A list of indices corresponding to the transitions in the replay buffer.
priorities (torch.Tensor or numpy.ndarray): The updated priorities for the corresponding transitions.
"""
if isinstance(priorities, torch.Tensor):
priorities = priorities.detach().cpu().numpy()

for data_idx, priority in zip(data_idxs, priorities):
priority = (priority + self.eps) ** self.alpha
self.tree.update(data_idx, priority)
self.max_priority = max(self.max_priority, priority)
self.max_priority = max(self.max_priority, priority)
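A quick numeric illustration of the priority transform applied in the loop above, using the defaults eps=1e-2 and alpha=0.1 from __init__. The incoming value is typically a TD-error magnitude, though where it is computed is not shown in this diff.

raw = 0.5
priority = (raw + 1e-2) ** 0.1   # = 0.51 ** 0.1, roughly 0.935; a small alpha flattens the spread
# max_priority is then raised if needed, so newly added transitions start at the
# highest priority seen so far (see self.tree.add(self.max_priority, self.count) in add()).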