forked from alexhernandezgarcia/ActiveLearningPipeline
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathreplay_buffer.py
100 lines (81 loc) · 2.79 KB
/
replay_buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Adapted from https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
from collections import namedtuple
import random
# Epsilon schedule constants, named as in the PyTorch DQN tutorial this module
# adapts (start value, end value, decay). They are not used by the replay-memory
# classes in this module; presumably imported by the training loop — confirm.
EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 10
# Transition records are defined at module level so that replay memories holding
# them can be pickled. pickle finds a namedtuple class by looking up its
# ``__name__`` in its defining module, so each typename MUST equal the
# module-level variable the class is bound to. (Both were previously named
# "Transition", which made pickling fail with an attribute-lookup error.)
Query_Transition = namedtuple(
    "Query_Transition",
    ("model_state", "action_state", "next_model_state", "next_action_state", "reward", "terminal"),
)
Parameter_Transition = namedtuple(
    "Parameter_Transition",
    ("model_state", "action", "next_model_state", "reward", "terminal"),
)
class QuerySelectionReplayMemory(object):
    """Fixed-capacity experience replay buffer of ``Query_Transition`` records.

    Transitions are appended until ``capacity`` entries are stored; after that,
    each push overwrites the slot at ``position`` (the oldest entry), so the
    buffer behaves as a ring buffer.
    """

    def __init__(self, capacity):
        """
        Args:
            capacity: maximum number of transitions kept in ``memory``.
        """
        self.capacity = capacity
        self.memory = []
        self.position = 0  # slot the next push writes to

    def push(
        self, model_state, action_state, next_model_state, next_action_state, reward, terminal
    ):
        """Saves a transition, overwriting the oldest one once the buffer is full."""
        if len(self.memory) < self.capacity:
            # Still filling: grow by one slot so ``position`` is a valid index.
            self.memory.append(None)
        self.memory[self.position] = Query_Transition(
            model_state, action_state, next_model_state, next_action_state, reward, terminal
        )
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions chosen uniformly at random.

        Raises:
            ValueError: if ``batch_size`` is negative or exceeds ``len(self)``
                (propagated from ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.memory)
class ParameterUpdateReplayMemory(object):
    """Fixed-capacity experience replay buffer for parameter-update transitions.

    Each transition is stored as a dict with keys ``"model_state"``,
    ``"action"``, ``"next_model_state"``, ``"reward"`` and ``"terminal"``.
    Transitions are appended until ``capacity`` entries are stored; after that,
    each push overwrites the slot at ``position`` (the oldest entry).
    """

    def __init__(self, capacity):
        """
        Args:
            capacity: maximum number of transitions kept in ``memory``.
        """
        self.capacity = capacity
        self.memory = []
        self.position = 0  # slot the next push writes to

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.memory)

    def push(self, model_state, action, next_model_state, reward, terminal):
        """Saves a transition, overwriting the oldest one once the buffer is full.

        ``model_state`` and ``next_model_state`` are stored as the result of
        their ``.cpu()`` method. NOTE(review): presumably torch tensors, so the
        buffer does not keep device memory alive — confirm against callers.
        """
        if len(self.memory) < self.capacity:
            # Still filling: grow by one slot so ``position`` is a valid index.
            self.memory.append(None)
        self.memory[self.position] = {
            "model_state": model_state.cpu(),
            "action": action,
            "next_model_state": next_model_state.cpu(),
            "reward": reward,
            "terminal": terminal,
        }
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions chosen uniformly at random.

        Raises:
            ValueError: if ``batch_size`` is negative or exceeds ``len(self)``
                (propagated from ``random.sample``).
        """
        # ``memory`` is already a list, so no defensive copy is needed.
        return random.sample(self.memory, batch_size)