Skip to content

Commit

Permalink
Merge pull request #35 from accuracy-maker/master
Browse files Browse the repository at this point in the history
add sumtree,PER and PESR
  • Loading branch information
kyegomez authored Dec 24, 2023
2 parents 288b482 + 07541b2 commit 1bb7da1
Show file tree
Hide file tree
Showing 6 changed files with 476 additions and 0 deletions.
61 changes: 61 additions & 0 deletions tests/rl/test_prioritizedreplybuffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
import random
import torch
from zeta.rl.PrioritizedReplayBuffer import PrioritizedReplayBuffer, SumTree # Replace 'your_module' with the actual module where classes are defined

@pytest.fixture
def replay_buffer():
state_size = 4
action_size = 2
buffer_size = 100
device = torch.device("cpu")
return PrioritizedReplayBuffer(state_size, action_size, buffer_size, device)

def test_initialization(replay_buffer):
assert replay_buffer.eps == 1e-2
assert replay_buffer.alpha == 0.1
assert replay_buffer.beta == 0.1
assert replay_buffer.max_priority == 1.0
assert replay_buffer.count == 0
assert replay_buffer.real_size == 0
assert replay_buffer.size == 100
assert replay_buffer.device == torch.device("cpu")

def test_add(replay_buffer):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
replay_buffer.add(transition)
assert replay_buffer.count == 1
assert replay_buffer.real_size == 1

def test_sample(replay_buffer):
for i in range(10):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
replay_buffer.add(transition)

batch, weights, tree_idxs = replay_buffer.sample(5)
assert len(batch) == 5
assert len(weights) == 5
assert len(tree_idxs) == 5

def test_update_priorities(replay_buffer):
for i in range(10):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
replay_buffer.add(transition)

batch, weights, tree_idxs = replay_buffer.sample(5)
new_priorities = torch.rand(5)
replay_buffer.update_priorities(tree_idxs, new_priorities)

def test_sample_with_invalid_batch_size(replay_buffer):
with pytest.raises(AssertionError):
replay_buffer.sample(101)

def test_add_with_max_size(replay_buffer):
for i in range(100):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
replay_buffer.add(transition)

assert replay_buffer.count == 0
assert replay_buffer.real_size == 100

# Additional tests for edge cases, exceptions, and more scenarios can be added as needed.
64 changes: 64 additions & 0 deletions tests/rl/test_prioritizedsequencereplybuffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
import random
import torch
from zeta.rl.PrioritizedSequenceReplayBuffer import PrioritizedSequenceReplayBuffer, SumTree # Replace 'your_module' with the actual module where classes are defined

@pytest.fixture
def replay_buffer():
state_size = 4
action_size = 2
buffer_size = 100
device = torch.device("cpu")
return PrioritizedSequenceReplayBuffer(state_size, action_size, buffer_size, device)

def test_initialization(replay_buffer):
assert replay_buffer.eps == 1e-5
assert replay_buffer.alpha == 0.1
assert replay_buffer.beta == 0.1
assert replay_buffer.max_priority == 1.0
assert replay_buffer.decay_window == 5
assert replay_buffer.decay_coff == 0.4
assert replay_buffer.pre_priority == 0.7
assert replay_buffer.count == 0
assert replay_buffer.real_size == 0
assert replay_buffer.size == 100
assert replay_buffer.device == torch.device("cpu")

def test_add(replay_buffer):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
replay_buffer.add(transition)
assert replay_buffer.count == 1
assert replay_buffer.real_size == 1

def test_sample(replay_buffer):
for i in range(10):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
replay_buffer.add(transition)

batch, weights, tree_idxs = replay_buffer.sample(5)
assert len(batch) == 5
assert len(weights) == 5
assert len(tree_idxs) == 5

def test_update_priorities(replay_buffer):
for i in range(10):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
replay_buffer.add(transition)

batch, weights, tree_idxs = replay_buffer.sample(5)
new_priorities = torch.rand(5)
replay_buffer.update_priorities(tree_idxs, new_priorities)

def test_sample_with_invalid_batch_size(replay_buffer):
with pytest.raises(AssertionError):
replay_buffer.sample(101)

def test_add_with_max_size(replay_buffer):
for i in range(100):
transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
replay_buffer.add(transition)

assert replay_buffer.count == 0
assert replay_buffer.real_size == 100

# Additional tests for edge cases, exceptions, and more scenarios can be added as needed.
56 changes: 56 additions & 0 deletions tests/rl/test_sumtree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest
from zeta.rl.sumtree import SumTree # Replace 'your_module' with the actual module where SumTree is defined

# Fixture for initializing SumTree instances with a given size
@pytest.fixture
def sum_tree():
size = 10 # You can change the size as needed
return SumTree(size)

# Basic tests
def test_initialization(sum_tree):
assert sum_tree.size == 10
assert sum_tree.count == 0
assert sum_tree.real_size == 0
assert sum_tree.total == 0

def test_update_and_get(sum_tree):
sum_tree.add(5, "data1")
assert sum_tree.total == 5
data_idx, priority, data = sum_tree.get(5)
assert data_idx == 0
assert priority == 5
assert data == "data1"

def test_add_overflow(sum_tree):
for i in range(15):
sum_tree.add(i, f"data{i}")
assert sum_tree.count == 5
assert sum_tree.real_size == 10

# Parameterized testing for various scenarios
@pytest.mark.parametrize("values, expected_total", [
([1, 2, 3, 4, 5], 15),
([10, 20, 30, 40, 50], 150),
])
def test_multiple_updates(sum_tree, values, expected_total):
for value in values:
sum_tree.add(value, None)
assert sum_tree.total == expected_total

# Exception testing
def test_get_with_invalid_cumsum(sum_tree):
with pytest.raises(AssertionError):
sum_tree.get(20)

# More tests for specific methods
def test_get_priority(sum_tree):
sum_tree.add(10, "data1")
priority = sum_tree.get_priority(0)
assert priority == 10

def test_repr(sum_tree):
expected_repr = f"SumTree(nodes={sum_tree.nodes}, data={sum_tree.data})"
assert repr(sum_tree) == expected_repr

# More test cases can be added as needed
85 changes: 85 additions & 0 deletions zeta/rl/PrioritizedReplayBuffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from sumtree import SumTree
import torch
import random

class PrioritizedReplayBuffer:
def __init__(self, state_size, action_size, buffer_size, device, eps=1e-2, alpha=0.1, beta=0.1):
self.tree = SumTree(size=buffer_size)


self.eps = eps
self.alpha = alpha
self.beta = beta
self.max_priority = 1.


self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
self.reward = torch.empty(buffer_size, dtype=torch.float)
self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
self.done = torch.empty(buffer_size, dtype=torch.uint8)

self.count = 0
self.real_size = 0
self.size = buffer_size

# device
self.device = device

def add(self, transition):
state, action, reward, next_state, done = transition


self.tree.add(self.max_priority, self.count)

self.state[self.count] = torch.as_tensor(state)
self.action[self.count] = torch.as_tensor(action)
self.reward[self.count] = torch.as_tensor(reward)
self.next_state[self.count] = torch.as_tensor(next_state)
self.done[self.count] = torch.as_tensor(done)


self.count = (self.count + 1) % self.size
self.real_size = min(self.size, self.real_size + 1)

def sample(self, batch_size):
assert self.real_size >= batch_size, "buffer contains less samples than batch size"

sample_idxs, tree_idxs = [], []
priorities = torch.empty(batch_size, 1, dtype=torch.float)


segment = self.tree.total / batch_size
for i in range(batch_size):
a, b = segment * i, segment * (i + 1)

cumsum = random.uniform(a, b)

tree_idx, priority, sample_idx = self.tree.get(cumsum)

priorities[i] = priority
tree_idxs.append(tree_idx)
sample_idxs.append(sample_idx)

probs = priorities / self.tree.total

weights = (self.real_size * probs) ** -self.beta

weights = weights / weights.max()
batch = (
self.state[sample_idxs].to(self.device),
self.action[sample_idxs].to(self.device),
self.reward[sample_idxs].to(self.device),
self.next_state[sample_idxs].to(self.device),
self.done[sample_idxs].to(self.device)
)
return batch, weights, tree_idxs

def update_priorities(self, data_idxs, priorities):
if isinstance(priorities, torch.Tensor):
priorities = priorities.detach().cpu().numpy()

for data_idx, priority in zip(data_idxs, priorities):
priority = (priority + self.eps) ** self.alpha
self.tree.update(data_idx, priority)
self.max_priority = max(self.max_priority, priority)
112 changes: 112 additions & 0 deletions zeta/rl/PrioritizedSequenceReplayBuffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from sumtree import SumTree
import torch
import random

class PrioritizedSequenceReplayBuffer:
def __init__(self,state_size,action_size,buffer_size,device,eps=1e-5,alpha=0.1,beta=0.1,
decay_window=5,
decay_coff=0.4,
pre_priority=0.7):
self.tree = SumTree(data_size=buffer_size)

# PESR params
self.eps = eps
self.alpha = alpha
self.beta = beta
self.max_priority = 1.
self.decay_window = decay_window
self.decay_coff = decay_coff
self.pre_priority = pre_priority

# buffer params
self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
self.reward = torch.empty(buffer_size, dtype=torch.float)
self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
self.done = torch.empty(buffer_size, dtype=torch.uint8)

self.count = 0
self.real_size = 0
self.size = buffer_size

# device
self.device = device

def add(self, transition):
state, action, reward, next_state, done = transition

# store transition index with maximum priority in sum tree
self.tree.add(self.max_priority, self.count)

# store transition in the buffer
self.state[self.count] = torch.as_tensor(state)
self.action[self.count] = torch.as_tensor(action)
self.reward[self.count] = torch.as_tensor(reward)
self.next_state[self.count] = torch.as_tensor(next_state)
self.done[self.count] = torch.as_tensor(done)

# update counters
self.count = (self.count + 1) % self.size
self.real_size = min(self.size, self.real_size + 1)

def sample(self,batch_size):
assert self.real_size >= batch_size, "buffer contains less samples than batch size"

sample_idxs, tree_idxs = [], []
priorities = torch.empty(batch_size, 1, dtype=torch.float)

segment = self.tree.total_priority / batch_size
for i in range(batch_size):
a, b = segment * i, segment * (i + 1)

cumsum = random.uniform(a, b)
# sample_idx is a sample index in buffer, needed further to sample actual transitions
# tree_idx is a index of a sample in the tree, needed further to update priorities
tree_idx, priority, sample_idx = self.tree.get(cumsum)

priorities[i] = priority
tree_idxs.append(tree_idx)
sample_idxs.append(sample_idx)
"""
Note:
The priorities stored in sumtree are all times alpha
"""
probs = priorities / self.tree.total_priority
weights = (self.real_size * probs) ** -self.beta
weights = weights / weights.max()
batch = (
self.state[sample_idxs].to(self.device),
self.action[sample_idxs].to(self.device),
self.reward[sample_idxs].to(self.device),
self.next_state[sample_idxs].to(self.device),
self.done[sample_idxs].to(self.device)
)
return batch, weights, tree_idxs

def update_priorities(self,data_idxs,abs_td_errors):
"""
when we get the TD-error, we should update the transition priority p_j
And update decay_window's transition priorities
"""
if isinstance(abs_td_errors,torch.Tensor):
abs_td_errors = abs_td_errors.detach().cpu().numpy()

for data_idx, td_error in zip(data_idxs,abs_td_errors):
# first update the batch: p_j
# p_j <- max{|delta_j| + eps, pre_priority * p_j}
old_priority = self.pre_priority * self.tree.nodes[data_idx + self.tree.size - 1]
priority = (td_error + self.eps) ** self.alpha
priority = max(priority,old_priority)
self.tree.update(data_idx,priority)
self.max_priority = max(self.max_priority,priority)

# And then apply decay
if self.count >= self.decay_window:
# count points to the next position
# count means the idx in the buffer and number of transition
for i in reversed(range(self.decay_window)):
idx = (self.count - i - 1) % self.size
decayed_priority = priority * (self.decay_coff ** (i + 1))
tree_idx = idx + self.tree.size - 1
existing_priority = self.tree.nodes[tree_idx]
self.tree.update(idx,max(decayed_priority,existing_priority))
Loading

0 comments on commit 1bb7da1

Please sign in to comment.