"""Tests for ``zeta.rl.PrioritizedReplayBuffer.PrioritizedReplayBuffer``."""
import pytest
import random
import torch
from zeta.rl.PrioritizedReplayBuffer import PrioritizedReplayBuffer, SumTree

STATE_SIZE = 4
ACTION_SIZE = 2
BUFFER_SIZE = 100


def _random_transition():
    """Build one ``(state, action, reward, next_state, done)`` transition."""
    return (
        torch.rand(STATE_SIZE),
        torch.rand(ACTION_SIZE),
        1.0,
        torch.rand(STATE_SIZE),
        False,
    )


def _fill(buffer, n):
    """Add ``n`` random transitions to ``buffer``."""
    for _ in range(n):
        buffer.add(_random_transition())


@pytest.fixture
def replay_buffer():
    """Return an empty CPU buffer with room for ``BUFFER_SIZE`` transitions."""
    device = torch.device("cpu")
    return PrioritizedReplayBuffer(STATE_SIZE, ACTION_SIZE, BUFFER_SIZE, device)


def test_initialization(replay_buffer):
    assert replay_buffer.eps == 1e-2
    assert replay_buffer.alpha == 0.1
    assert replay_buffer.beta == 0.1
    assert replay_buffer.max_priority == 1.0
    assert replay_buffer.count == 0
    assert replay_buffer.real_size == 0
    assert replay_buffer.size == 100
    assert replay_buffer.device == torch.device("cpu")


def test_add(replay_buffer):
    replay_buffer.add(_random_transition())
    assert replay_buffer.count == 1
    assert replay_buffer.real_size == 1


def test_sample(replay_buffer):
    _fill(replay_buffer, 10)

    batch, weights, tree_idxs = replay_buffer.sample(5)

    # ``batch`` is a tuple of five fields; the batch dimension is the leading
    # dimension of each field, so check that rather than ``len(batch)`` alone.
    assert len(batch) == 5
    assert all(field.shape[0] == 5 for field in batch)
    assert len(weights) == 5
    assert len(tree_idxs) == 5
    # Weights are normalised by their maximum, so they lie in (0, 1].
    assert torch.all(weights > 0)
    assert torch.all(weights <= 1.0)


def test_update_priorities(replay_buffer):
    _fill(replay_buffer, 10)

    batch, weights, tree_idxs = replay_buffer.sample(5)
    new_priorities = torch.rand(5)
    replay_buffer.update_priorities(tree_idxs, new_priorities)

    # Every touched leaf must hold a strictly positive priority (eps > 0),
    # and the running maximum never decreases below its initial value.
    assert all(replay_buffer.tree.get_priority(idx) > 0 for idx in tree_idxs)
    assert replay_buffer.max_priority >= 1.0
    assert replay_buffer.tree.total > 0


def test_sample_with_invalid_batch_size(replay_buffer):
    with pytest.raises(AssertionError):
        replay_buffer.sample(101)


def test_add_with_max_size(replay_buffer):
    _fill(replay_buffer, 100)

    # The write cursor wraps to 0 once the buffer is exactly full.
    assert replay_buffer.count == 0
    assert replay_buffer.real_size == 100
"""Tests for ``zeta.rl.PrioritizedSequenceReplayBuffer``."""
import pytest
import random
import torch
from zeta.rl.PrioritizedSequenceReplayBuffer import PrioritizedSequenceReplayBuffer, SumTree

STATE_SIZE = 4
ACTION_SIZE = 2
BUFFER_SIZE = 100


def _random_transition():
    """Build one ``(state, action, reward, next_state, done)`` transition."""
    return (
        torch.rand(STATE_SIZE),
        torch.rand(ACTION_SIZE),
        1.0,
        torch.rand(STATE_SIZE),
        False,
    )


def _fill(buffer, n):
    """Add ``n`` random transitions to ``buffer``."""
    for _ in range(n):
        buffer.add(_random_transition())


@pytest.fixture
def replay_buffer():
    """Return an empty CPU buffer with room for ``BUFFER_SIZE`` transitions."""
    device = torch.device("cpu")
    return PrioritizedSequenceReplayBuffer(STATE_SIZE, ACTION_SIZE, BUFFER_SIZE, device)


def test_initialization(replay_buffer):
    assert replay_buffer.eps == 1e-5
    assert replay_buffer.alpha == 0.1
    assert replay_buffer.beta == 0.1
    assert replay_buffer.max_priority == 1.0
    assert replay_buffer.decay_window == 5
    assert replay_buffer.decay_coff == 0.4
    assert replay_buffer.pre_priority == 0.7
    assert replay_buffer.count == 0
    assert replay_buffer.real_size == 0
    assert replay_buffer.size == 100
    assert replay_buffer.device == torch.device("cpu")


def test_add(replay_buffer):
    replay_buffer.add(_random_transition())
    assert replay_buffer.count == 1
    assert replay_buffer.real_size == 1


def test_sample(replay_buffer):
    _fill(replay_buffer, 10)

    batch, weights, tree_idxs = replay_buffer.sample(5)

    # ``batch`` is a tuple of five fields; the batch dimension is the leading
    # dimension of each field, so check that rather than ``len(batch)`` alone.
    assert len(batch) == 5
    assert all(field.shape[0] == 5 for field in batch)
    assert len(weights) == 5
    assert len(tree_idxs) == 5
    # Weights are normalised by their maximum, so they lie in (0, 1].
    assert torch.all(weights > 0)
    assert torch.all(weights <= 1.0)


def test_update_priorities(replay_buffer):
    _fill(replay_buffer, 10)

    batch, weights, tree_idxs = replay_buffer.sample(5)
    new_priorities = torch.rand(5)
    replay_buffer.update_priorities(tree_idxs, new_priorities)

    # Every touched leaf must hold a strictly positive priority (eps > 0),
    # and the running maximum never decreases below its initial value.
    assert all(replay_buffer.tree.get_priority(idx) > 0 for idx in tree_idxs)
    assert replay_buffer.max_priority >= 1.0
    assert replay_buffer.tree.total > 0


def test_sample_with_invalid_batch_size(replay_buffer):
    with pytest.raises(AssertionError):
        replay_buffer.sample(101)


def test_add_with_max_size(replay_buffer):
    _fill(replay_buffer, 100)

    # The write cursor wraps to 0 once the buffer is exactly full.
    assert replay_buffer.count == 0
    assert replay_buffer.real_size == 100
"""Tests for ``zeta.rl.sumtree.SumTree``."""
import pytest
from zeta.rl.sumtree import SumTree


@pytest.fixture
def sum_tree():
    """Return an empty ``SumTree`` with ten leaves."""
    return SumTree(10)


# Basic tests
def test_initialization(sum_tree):
    assert sum_tree.size == 10
    assert sum_tree.count == 0
    assert sum_tree.real_size == 0
    assert sum_tree.total == 0


def test_update_and_get(sum_tree):
    sum_tree.add(5, "data1")
    assert sum_tree.total == 5
    data_idx, priority, data = sum_tree.get(5)
    assert data_idx == 0
    assert priority == 5
    assert data == "data1"


def test_update_changes_total(sum_tree):
    sum_tree.add(1, "a")
    sum_tree.add(2, "b")
    sum_tree.update(0, 10)
    assert sum_tree.get_priority(0) == 10
    assert sum_tree.total == 12


def test_add_overflow(sum_tree):
    for i in range(15):
        sum_tree.add(i, f"data{i}")
    assert sum_tree.count == 5
    assert sum_tree.real_size == 10
    # Slots 0-4 were overwritten by entries 10-14; slots 5-9 are untouched.
    expected_data = [f"data{i}" for i in range(10, 15)]
    expected_data += [f"data{i}" for i in range(5, 10)]
    assert sum_tree.data == expected_data
    assert sum_tree.total == sum(range(5, 15))


# Parameterized testing for various scenarios
@pytest.mark.parametrize(
    "values, expected_total",
    [
        ([1, 2, 3, 4, 5], 15),
        ([10, 20, 30, 40, 50], 150),
    ],
)
def test_multiple_updates(sum_tree, values, expected_total):
    for value in values:
        sum_tree.add(value, None)
    assert sum_tree.total == expected_total


# Exception testing
def test_get_with_invalid_cumsum(sum_tree):
    # The tree is empty (total == 0), so any positive cumsum is out of range.
    with pytest.raises(AssertionError):
        sum_tree.get(20)


# More tests for specific methods
def test_get_priority(sum_tree):
    sum_tree.add(10, "data1")
    priority = sum_tree.get_priority(0)
    assert priority == 10


def test_repr(sum_tree):
    expected_repr = f"SumTree(nodes={sum_tree.nodes}, data={sum_tree.data})"
    assert repr(sum_tree) == expected_repr
"""Prioritized replay buffer backed by a sum tree."""
import random

import torch

# Absolute import: the tests load this module as
# ``zeta.rl.PrioritizedReplayBuffer``, where a bare ``from sumtree import ...``
# cannot be resolved.
from zeta.rl.sumtree import SumTree


class PrioritizedReplayBuffer:
    """Fixed-capacity transition store with priority-proportional sampling.

    Transitions are written into preallocated tensors in ring-buffer order
    (the oldest slot is overwritten once ``buffer_size`` is reached). Each
    slot's sampling priority is kept in a ``SumTree``; new transitions enter
    with the largest priority seen so far.

    Args:
        state_size: Length of a state vector.
        action_size: Length of an action vector.
        buffer_size: Maximum number of stored transitions.
        device: Torch device that sampled tensors are moved to.
        eps: Constant added to a priority before exponentiation, so that no
            updated priority is exactly zero.
        alpha: Exponent applied to ``priority + eps`` in
            ``update_priorities``.
        beta: Exponent of the importance weights computed in ``sample``.
    """

    def __init__(self, state_size, action_size, buffer_size, device, eps=1e-2, alpha=0.1, beta=0.1):
        self.tree = SumTree(size=buffer_size)

        # Priority hyper-parameters.
        self.eps = eps
        self.alpha = alpha
        self.beta = beta
        self.max_priority = 1.0  # priority given to freshly added transitions

        # Transition storage, one row per slot.
        self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.uint8)

        self.count = 0  # next slot to write
        self.real_size = 0  # number of valid slots
        self.size = buffer_size

        self.device = device

    def add(self, transition):
        """Store one ``(state, action, reward, next_state, done)`` transition.

        The slot index is registered in the sum tree with ``max_priority`` so
        that a new transition is likely to be sampled at least once.
        """
        state, action, reward, next_state, done = transition

        self.tree.add(self.max_priority, self.count)

        slot = self.count
        self.state[slot] = torch.as_tensor(state)
        self.action[slot] = torch.as_tensor(action)
        self.reward[slot] = torch.as_tensor(reward)
        self.next_state[slot] = torch.as_tensor(next_state)
        self.done[slot] = torch.as_tensor(done)

        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions by stratified priority sampling.

        The total priority mass is cut into ``batch_size`` equal segments and
        one point is drawn uniformly from each.

        Returns:
            ``(batch, weights, tree_idxs)`` where ``batch`` is the tuple
            ``(state, action, reward, next_state, done)`` on ``self.device``,
            ``weights`` has shape ``(batch_size, 1)`` and is scaled so its
            maximum is 1, and ``tree_idxs`` holds the first element returned
            by ``SumTree.get`` for each draw. That element is the index
            ``SumTree.update`` expects, so it can be passed straight to
            ``update_priorities``.

        Raises:
            AssertionError: If fewer than ``batch_size`` transitions are stored.
            ValueError: If ``batch_size`` is not positive.
        """
        assert self.real_size >= batch_size, "buffer contains less samples than batch size"
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        sample_idxs, tree_idxs = [], []
        priorities = torch.empty(batch_size, 1, dtype=torch.float)

        total = self.tree.total
        segment = total / batch_size
        for i in range(batch_size):
            low, high = segment * i, segment * (i + 1)
            # Clamp: rounding in ``segment * (i + 1)`` can exceed ``total`` by
            # an ulp, which would trip the bound check inside ``SumTree.get``.
            cumsum = min(random.uniform(low, high), total)

            tree_idx, priority, sample_idx = self.tree.get(cumsum)

            priorities[i] = priority
            tree_idxs.append(tree_idx)
            sample_idxs.append(sample_idx)

        probs = priorities / total
        weights = (self.real_size * probs) ** -self.beta
        # Normalise by the largest weight, then move the weights to the same
        # device as the batch they are used with.
        weights = (weights / weights.max()).to(self.device)

        batch = (
            self.state[sample_idxs].to(self.device),
            self.action[sample_idxs].to(self.device),
            self.reward[sample_idxs].to(self.device),
            self.next_state[sample_idxs].to(self.device),
            self.done[sample_idxs].to(self.device),
        )
        return batch, weights, tree_idxs

    def update_priorities(self, data_idxs, priorities):
        """Set new priorities for previously sampled transitions.

        Args:
            data_idxs: Indices as returned in ``tree_idxs`` by ``sample``.
            priorities: One raw priority per index (tensor or iterable of
                numbers); stored as ``(priority + eps) ** alpha``.
                Presumably absolute TD errors, i.e. non-negative; verify
                against the caller.
        """
        if isinstance(priorities, torch.Tensor):
            # Flatten so that ``(N,)`` and ``(N, 1)`` inputs both yield plain
            # Python floats rather than numpy scalars or 1-element arrays.
            priorities = priorities.detach().cpu().flatten().tolist()

        for data_idx, priority in zip(data_idxs, priorities):
            priority = float((priority + self.eps) ** self.alpha)
            self.tree.update(data_idx, priority)
            self.max_priority = max(self.max_priority, priority)
"""Prioritized sequence replay buffer backed by a sum tree."""
import random

import torch

# Absolute import: the tests load this module as
# ``zeta.rl.PrioritizedSequenceReplayBuffer``, where a bare
# ``from sumtree import ...`` cannot be resolved.
from zeta.rl.sumtree import SumTree


class PrioritizedSequenceReplayBuffer:
    """Prioritized replay buffer that also spreads priority to nearby slots.

    Storage and sampling work like a plain prioritized buffer: transitions
    are kept in preallocated tensors written in ring-buffer order, and a
    ``SumTree`` holds one priority per slot. ``update_priorities``
    additionally (a) never lets a priority drop below ``pre_priority`` times
    its previous value and (b) raises the priorities of the ``decay_window``
    most recently written slots to a geometrically decayed copy of the new
    priority.

    Args:
        state_size: Length of a state vector.
        action_size: Length of an action vector.
        buffer_size: Maximum number of stored transitions.
        device: Torch device that sampled tensors are moved to.
        eps: Constant added to a TD error before exponentiation.
        alpha: Exponent applied to ``td_error + eps``.
        beta: Exponent of the importance weights computed in ``sample``.
        decay_window: Number of slots that receive a decayed priority.
        decay_coff: Per-step decay factor inside the window.
        pre_priority: Fraction of the old priority used as a lower bound for
            the new one.
    """

    def __init__(self, state_size, action_size, buffer_size, device, eps=1e-5, alpha=0.1, beta=0.1,
                 decay_window=5,
                 decay_coff=0.4,
                 pre_priority=0.7):
        # ``SumTree.__init__`` takes ``size``; ``data_size`` raised TypeError.
        self.tree = SumTree(size=buffer_size)

        # Priority hyper-parameters.
        self.eps = eps
        self.alpha = alpha
        self.beta = beta
        self.max_priority = 1.0  # priority given to freshly added transitions
        self.decay_window = decay_window
        self.decay_coff = decay_coff
        self.pre_priority = pre_priority

        # Transition storage, one row per slot.
        self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.uint8)

        self.count = 0  # next slot to write
        self.real_size = 0  # number of valid slots
        self.size = buffer_size

        self.device = device

    def add(self, transition):
        """Store one ``(state, action, reward, next_state, done)`` transition.

        The slot index is registered in the sum tree with ``max_priority``.
        """
        state, action, reward, next_state, done = transition

        # Store the slot index with maximum priority in the sum tree.
        self.tree.add(self.max_priority, self.count)

        # Store the transition itself.
        slot = self.count
        self.state[slot] = torch.as_tensor(state)
        self.action[slot] = torch.as_tensor(action)
        self.reward[slot] = torch.as_tensor(reward)
        self.next_state[slot] = torch.as_tensor(next_state)
        self.done[slot] = torch.as_tensor(done)

        # Advance the ring-buffer cursor.
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions by stratified priority sampling.

        Returns:
            ``(batch, weights, tree_idxs)`` where ``batch`` is the tuple
            ``(state, action, reward, next_state, done)`` on ``self.device``,
            ``weights`` has shape ``(batch_size, 1)`` and is scaled so its
            maximum is 1, and ``tree_idxs`` holds the first element returned
            by ``SumTree.get`` for each draw. That element is the index
            ``SumTree.update`` expects, so it can be passed straight to
            ``update_priorities``.

        Raises:
            AssertionError: If fewer than ``batch_size`` transitions are stored.
            ValueError: If ``batch_size`` is not positive.
        """
        assert self.real_size >= batch_size, "buffer contains less samples than batch size"
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        sample_idxs, tree_idxs = [], []
        priorities = torch.empty(batch_size, 1, dtype=torch.float)

        # ``SumTree`` exposes the root sum as ``total``; the former
        # ``total_priority`` attribute does not exist.
        total = self.tree.total
        segment = total / batch_size
        for i in range(batch_size):
            low, high = segment * i, segment * (i + 1)
            # Clamp: rounding in ``segment * (i + 1)`` can exceed ``total`` by
            # an ulp, which would trip the bound check inside ``SumTree.get``.
            cumsum = min(random.uniform(low, high), total)

            # sample_idx indexes the storage tensors; tree_idx is what
            # ``update_priorities`` needs later.
            tree_idx, priority, sample_idx = self.tree.get(cumsum)

            priorities[i] = priority
            tree_idxs.append(tree_idx)
            sample_idxs.append(sample_idx)

        # Priorities stored in the tree are already raised to ``alpha``.
        probs = priorities / total
        weights = (self.real_size * probs) ** -self.beta
        # Normalise by the largest weight, then move the weights to the same
        # device as the batch they are used with.
        weights = (weights / weights.max()).to(self.device)

        batch = (
            self.state[sample_idxs].to(self.device),
            self.action[sample_idxs].to(self.device),
            self.reward[sample_idxs].to(self.device),
            self.next_state[sample_idxs].to(self.device),
            self.done[sample_idxs].to(self.device),
        )
        return batch, weights, tree_idxs

    def update_priorities(self, data_idxs, abs_td_errors):
        """Update sampled priorities from TD errors and apply window decay.

        For each ``(data_idx, td_error)`` pair the new priority is
        ``max((td_error + eps) ** alpha, pre_priority * old_priority)``.
        Afterwards the ``decay_window`` slots written most recently are
        raised to ``priority * decay_coff ** k`` (``k = 1`` for the newest
        slot) whenever that exceeds their current priority.

        Args:
            data_idxs: Indices as returned in ``tree_idxs`` by ``sample``.
            abs_td_errors: One absolute TD error per index (tensor or
                iterable of numbers).
        """
        if isinstance(abs_td_errors, torch.Tensor):
            # Flatten so that ``(N,)`` and ``(N, 1)`` inputs both yield plain
            # Python floats rather than numpy scalars or 1-element arrays.
            abs_td_errors = abs_td_errors.detach().cpu().flatten().tolist()

        for data_idx, td_error in zip(data_idxs, abs_td_errors):
            # p_j <- max{(|delta_j| + eps) ** alpha, pre_priority * p_j}
            old_priority = self.pre_priority * self.tree.get_priority(data_idx)
            priority = max(float((td_error + self.eps) ** self.alpha), old_priority)
            self.tree.update(data_idx, priority)
            self.max_priority = max(self.max_priority, priority)

            # Decay only once at least ``decay_window`` slots hold data.
            # ``real_size`` (not ``count``) is the number of stored
            # transitions: ``count`` wraps to 0 when the ring is full, which
            # used to switch decay off right after every wrap-around.
            # NOTE(review): the window is anchored at ``count`` (the latest
            # writes), not at ``data_idx`` — confirm this is intended.
            if self.real_size >= self.decay_window:
                for i in reversed(range(self.decay_window)):
                    idx = (self.count - i - 1) % self.size
                    decayed_priority = priority * (self.decay_coff ** (i + 1))
                    existing_priority = self.tree.get_priority(idx)
                    self.tree.update(idx, max(decayed_priority, existing_priority))
class SumTree:
    """Binary sum tree stored in a flat list, used for priority sampling.

    The tree has ``size`` leaves, one per data slot, stored at
    ``nodes[size - 1:]``; every internal node holds the sum of its two
    children, so ``nodes[0]`` is the total priority. Data slots are filled
    in ring-buffer order: once ``size`` items were added, the oldest slot is
    overwritten.

    Args:
        size: Number of data slots (leaves); must be at least 1.

    Raises:
        ValueError: If ``size`` is smaller than 1.
    """

    def __init__(self, size):
        if size < 1:
            raise ValueError("SumTree size must be at least 1")

        self.nodes = [0] * (2 * size - 1)
        self.data = [None] * size

        self.size = size
        self.count = 0  # next data slot to write
        self.real_size = 0  # number of slots written so far, capped at size

    @property
    def total(self):
        """Sum of all leaf priorities (the value stored at the root)."""
        return self.nodes[0]

    def propagate(self, idx, delta_value):
        """Add ``delta_value`` to every ancestor of tree node ``idx``."""
        parent = (idx - 1) // 2

        # (0 - 1) // 2 == -1, so the loop stops right after the root.
        while parent >= 0:
            self.nodes[parent] += delta_value
            parent = (parent - 1) // 2

    def update(self, data_idx, value):
        """Set the priority of data slot ``data_idx`` to ``value``."""
        idx = data_idx + self.size - 1  # leaf position in the node array
        delta_value = value - self.nodes[idx]

        self.nodes[idx] = value

        self.propagate(idx, delta_value)

    def add(self, value, data):
        """Store ``data`` with priority ``value`` in the next ring slot."""
        self.data[self.count] = data
        self.update(self.count, value)

        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def get(self, cumsum):
        """Find the leaf whose cumulative priority range contains ``cumsum``.

        Args:
            cumsum: Value in ``[0, total]``.

        Returns:
            ``(data_idx, priority, data)`` for the selected slot.

        Raises:
            AssertionError: If ``cumsum`` exceeds ``total``.
        """
        assert cumsum <= self.total, "cumsum exceeds the total priority"

        idx = 0
        while 2 * idx + 1 < len(self.nodes):
            left, right = 2 * idx + 1, 2 * idx + 2
            left_sum, right_sum = self.nodes[left], self.nodes[right]

            # Never descend into a subtree without priority mass while the
            # sibling has some. Otherwise ``cumsum == 0`` (or a value nudged
            # past ``left_sum`` by float rounding) could end on an empty leaf
            # and return ``data=None``.
            if right_sum <= 0 or (left_sum > 0 and cumsum <= left_sum):
                idx = left
            else:
                idx = right
                cumsum = cumsum - left_sum

        data_idx = idx - self.size + 1

        return data_idx, self.nodes[idx], self.data[data_idx]

    def get_priority(self, data_idx):
        """Return the priority currently stored for data slot ``data_idx``."""
        tree_idx = data_idx + self.size - 1
        return self.nodes[tree_idx]

    def __repr__(self):
        return f"SumTree(nodes={self.nodes!r}, data={self.data!r})"


# # Test the sum tree
# if __name__ == '__main__':
# Assuming the SumTree class definition is available + +# # Function to print the state of the tree for easier debugging +# def print_tree(tree): +# print("Tree Total:", tree.total) +# print("Tree Nodes:", tree.nodes) +# print("Tree Data:", tree.data) +# print() + +# # Create a SumTree instance +# tree_size = 5 +# tree = SumTree(tree_size) + +# # Add some data with initial priorities +# print("Adding data to the tree...") +# for i in range(tree_size): +# data = f"Data-{i}" +# priority = i + 1 # Priority is just a simple increasing number for this test +# tree.add(priority, data) +# print_tree(tree) + +# # Update priority of a data item +# print("Updating priority...") +# update_index = 2 # For example, update the priority of the third item +# new_priority = 10 +# tree.update(update_index, new_priority) +# print_tree(tree) + +# # Retrieve data based on cumulative sum +# print("Retrieving data based on cumulative sum...") +# cumulative_sums = [5, 15, 20] # Test with different cumulative sums +# for cumsum in cumulative_sums: +# idx, node_value, data = tree.get(cumsum) +# print(f"Cumulative Sum: {cumsum} -> Retrieved: {data} with Priority: {node_value}") +# print()