From c74c1d431012f7df1c688cd6fff0864d3183f7bb Mon Sep 17 00:00:00 2001 From: Kaya Celebi Date: Sun, 8 Sep 2024 14:16:17 -0400 Subject: [PATCH] Implemented missing q learning func, added util.py --- src/machine_learning/learning_agents.py | 2 +- src/machine_learning/q_learning_agents.py | 29 ++- src/machine_learning/util.py | 216 ++++++++++++++++++++++ 3 files changed, 240 insertions(+), 7 deletions(-) create mode 100644 src/machine_learning/util.py diff --git a/src/machine_learning/learning_agents.py b/src/machine_learning/learning_agents.py index 5429b9e..9f64f20 100644 --- a/src/machine_learning/learning_agents.py +++ b/src/machine_learning/learning_agents.py @@ -101,7 +101,7 @@ def is_in_testing(self): ''' - actionFn: Function which takes a state and returns the list of legal actions + action_func: Function which takes a state and returns the list of legal actions alpha - learning rate epsilon - exploration rate diff --git a/src/machine_learning/q_learning_agents.py b/src/machine_learning/q_learning_agents.py index 480da05..19e63d2 100644 --- a/src/machine_learning/q_learning_agents.py +++ b/src/machine_learning/q_learning_agents.py @@ -4,6 +4,7 @@ sys.path.append('../') from machine_learning.learning_agents import * +import machine_learning.util as util from path_handler import PathHandler as PH import sim.automata as atm from sim.rules import Rules @@ -14,24 +15,25 @@ class QLearningAgent(ReinforcementAgent): def __init__(self, **args): ReinforcementAgent.__init__(self, **args) - self.q_values = util.Counter()? + # how are we going to implement this counter obj + self.q_values = util.Counter() def get_Q_value(self, state, action): - """ + ''' Returns Q(state,action) Should return 0.0 if we have never seen a state or the Q node value otherwise - """ + ''' return self.q_values[(state, action)] def compute_value_from_Q_values(self, state): - """ + ''' Returns max_action Q(state,action) where the max is over legal actions. Note that if there are no legal actions, which is the case at the terminal state, you should return a value of 0.0. - """ + ''' legal_actions = self.get_legal_actions(state) @@ -39,7 +41,7 @@ def compute_value_from_Q_values(self, state): if len(legal_actions) == 0: return 0.0 - return max([self.getQValue(state, a) for a in legal_actions]) + return max([self.get_Q_value(state, a) for a in legal_actions]) def compute_action_from_Q_values(self, state): @@ -68,3 +70,18 @@ def get_action(self, state): return random.choice(legal_actions) return self.compute_action_from_Q_values(state) + + # update our q values here + def update(self, state, action, next_state, reward): + # from value estmation parent.parent + NSQ = self.get_value(next_state) + + + self.q_values[(state, action)] = self.get_Q_value(state, action) + self.alpha * (reward + self.discount*NSQ - self.get_Q_value(state, action)) + + + + + + + diff --git a/src/machine_learning/util.py b/src/machine_learning/util.py new file mode 100644 index 0000000..e07d3d2 --- /dev/null +++ b/src/machine_learning/util.py @@ -0,0 +1,216 @@ +class Counter(dict): + """ + A counter keeps track of counts for a set of keys. + + The counter class is an extension of the standard python + dictionary type. It is specialized to have number values + (integers or floats), and includes a handful of additional + functions to ease the task of counting data. In particular, + all keys are defaulted to have value 0. Using a dictionary: + + a = {} + print a['test'] + + would give an error, while the Counter class analogue: + + >>> a = Counter() + >>> print a['test'] + 0 + + returns the default 0 value. Note that to reference a key + that you know is contained in the counter, + you can still use the dictionary syntax: + + >>> a = Counter() + >>> a['test'] = 2 + >>> print a['test'] + 2 + + This is very useful for counting things without initializing their counts, + see for example: + + >>> a['blah'] += 1 + >>> print a['blah'] + 1 + + The counter also includes additional functionality useful in implementing + the classifiers for this assignment. Two counters can be added, + subtracted or multiplied together. See below for details. They can + also be normalized and their total count and arg max can be extracted. + """ + + def __getitem__(self, idx): + self.setdefault(idx, 0) + return dict.__getitem__(self, idx) + + def incrementAll(self, keys, count): + """ + Increments all elements of keys by the same count. + + >>> a = Counter() + >>> a.incrementAll(['one','two', 'three'], 1) + >>> a['one'] + 1 + >>> a['two'] + 1 + """ + for key in keys: + self[key] += count + + def argMax(self): + """ + Returns the key with the highest value. + """ + if len(list(self.keys())) == 0: + return None + all = list(self.items()) + values = [x[1] for x in all] + maxIndex = values.index(max(values)) + return all[maxIndex][0] + + def sortedKeys(self): + """ + Returns a list of keys sorted by their values. Keys + with the highest values will appear first. + + >>> a = Counter() + >>> a['first'] = -2 + >>> a['second'] = 4 + >>> a['third'] = 1 + >>> a.sortedKeys() + ['second', 'third', 'first'] + """ + sortedItems = list(self.items()) + + def compare(x, y): return sign(y[1] - x[1]) + sortedItems.sort(cmp=compare) + return [x[0] for x in sortedItems] + + def totalCount(self): + """ + Returns the sum of counts for all keys. + """ + return sum(self.values()) + + def normalize(self): + """ + Edits the counter such that the total count of all + keys sums to 1. The ratio of counts for all keys + will remain the same. Note that normalizing an empty + Counter will result in an error. + """ + total = float(self.totalCount()) + if total == 0: + return + for key in list(self.keys()): + self[key] = self[key] / total + + def divideAll(self, divisor): + """ + Divides all counts by divisor + """ + divisor = float(divisor) + for key in self: + self[key] /= divisor + + def copy(self): + """ + Returns a copy of the counter + """ + return Counter(dict.copy(self)) + + def __mul__(self, y): + """ + Multiplying two counters gives the dot product of their vectors where + each unique label is a vector element. + + >>> a = Counter() + >>> b = Counter() + >>> a['first'] = -2 + >>> a['second'] = 4 + >>> b['first'] = 3 + >>> b['second'] = 5 + >>> a['third'] = 1.5 + >>> a['fourth'] = 2.5 + >>> a * b + 14 + """ + sum = 0 + x = self + if len(x) > len(y): + x, y = y, x + for key in x: + if key not in y: + continue + sum += x[key] * y[key] + return sum + + def __radd__(self, y): + """ + Adding another counter to a counter increments the current counter + by the values stored in the second counter. + + >>> a = Counter() + >>> b = Counter() + >>> a['first'] = -2 + >>> a['second'] = 4 + >>> b['first'] = 3 + >>> b['third'] = 1 + >>> a += b + >>> a['first'] + 1 + """ + for key, value in list(y.items()): + self[key] += value + + def __add__(self, y): + """ + Adding two counters gives a counter with the union of all keys and + counts of the second added to counts of the first. + + >>> a = Counter() + >>> b = Counter() + >>> a['first'] = -2 + >>> a['second'] = 4 + >>> b['first'] = 3 + >>> b['third'] = 1 + >>> (a + b)['first'] + 1 + """ + addend = Counter() + for key in self: + if key in y: + addend[key] = self[key] + y[key] + else: + addend[key] = self[key] + for key in y: + if key in self: + continue + addend[key] = y[key] + return addend + + def __sub__(self, y): + """ + Subtracting a counter from another gives a counter with the union of all keys and + counts of the second subtracted from counts of the first. + + >>> a = Counter() + >>> b = Counter() + >>> a['first'] = -2 + >>> a['second'] = 4 + >>> b['first'] = 3 + >>> b['third'] = 1 + >>> (a - b)['first'] + -5 + """ + addend = Counter() + for key in self: + if key in y: + addend[key] = self[key] - y[key] + else: + addend[key] = self[key] + for key in y: + if key in self: + continue + addend[key] = -1 * y[key] + return addend