Skip to content

Commit

Permalink
Implemented missing q learning func, added util.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kcelebi committed Sep 8, 2024
1 parent 25ef8ea commit c74c1d4
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/machine_learning/learning_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def is_in_testing(self):


'''
actionFn: Function which takes a state and returns the list of legal actions
action_func: Function which takes a state and returns the list of legal actions
alpha - learning rate
epsilon - exploration rate
Expand Down
29 changes: 23 additions & 6 deletions src/machine_learning/q_learning_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
sys.path.append('../')

from machine_learning.learning_agents import *
import machine_learning.util as util
from path_handler import PathHandler as PH
import sim.automata as atm
from sim.rules import Rules
Expand All @@ -14,32 +15,33 @@ class QLearningAgent(ReinforcementAgent):
def __init__(self, **args):
    '''
    Builds the agent and its (initially empty) Q-table.
    All keyword arguments are forwarded unchanged to
    ReinforcementAgent.__init__.
    '''
    ReinforcementAgent.__init__(self, **args)

    # Q-table keyed by (state, action) tuples. util.Counter reads
    # missing keys as 0, so unseen pairs need no explicit initialisation.
    self.q_values = util.Counter()


def get_Q_value(self, state, action):
    '''
    Returns Q(state, action) as stored in self.q_values.
    The table is indexed by the (state, action) tuple; when it is the
    util.Counter created in __init__, a pair that has never been
    seen reads as 0 (and is recorded in the table by that read).
    '''
    return self.q_values[(state, action)]

def compute_value_from_Q_values(self, state):
    '''
    Returns max_action Q(state, action), where the max is taken
    over the legal actions reported by self.get_legal_actions(state).
    If there are no legal actions (the terminal-state case) the
    value is 0.0.
    '''
    legal_actions = self.get_legal_actions(state)

    # Terminal: nothing to maximise over
    if not legal_actions:
        return 0.0

    return max(self.get_Q_value(state, a) for a in legal_actions)


def compute_action_from_Q_values(self, state):
Expand Down Expand Up @@ -68,3 +70,18 @@ def get_action(self, state):
return random.choice(legal_actions)
return self.compute_action_from_Q_values(state)


# update our q values here
def update(self, state, action, next_state, reward):
    '''
    Applies one Q-learning update for the observed transition
    (state, action) -> next_state with the given reward:

        Q(s, a) <- Q(s, a) + alpha * (reward + discount * V(s') - Q(s, a))

    V(s') is obtained from self.get_value(next_state); alpha and
    discount are read from self.alpha and self.discount.
    The new estimate is written back into self.q_values.
    '''
    # Value estimate of the successor state (get_value is inherited)
    next_state_value = self.get_value(next_state)

    # Read the current estimate once and reuse it in the update
    current_q = self.get_Q_value(state, action)

    self.q_values[(state, action)] = current_q + self.alpha * (
        reward + self.discount * next_state_value - current_q
    )







216 changes: 216 additions & 0 deletions src/machine_learning/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
class Counter(dict):
    """
    A counter keeps track of counts for a set of keys.

    The counter class is an extension of the standard python
    dictionary type. It is specialized to have number values
    (integers or floats), and includes a handful of additional
    functions to ease the task of counting data. In particular,
    all keys are defaulted to have value 0. Using a dictionary:

    a = {}
    print(a['test'])

    would give an error, while the Counter class analogue:

    >>> a = Counter()
    >>> print(a['test'])
    0

    returns the default 0 value. Note that reading a missing key also
    stores it in the counter with value 0. To reference a key
    that you know is contained in the counter,
    you can still use the dictionary syntax:

    >>> a = Counter()
    >>> a['test'] = 2
    >>> print(a['test'])
    2

    This is very useful for counting things without initializing their counts,
    see for example:

    >>> a['blah'] += 1
    >>> print(a['blah'])
    1

    Two counters can be added, subtracted or multiplied together.
    See below for details. They can also be normalized and their
    total count and arg max can be extracted.
    """

    def __getitem__(self, idx):
        """
        Returns the value stored under idx, inserting idx with value 0
        first if it is not present yet.
        """
        return self.setdefault(idx, 0)

    def incrementAll(self, keys, count):
        """
        Increments all elements of keys by the same count.

        >>> a = Counter()
        >>> a.incrementAll(['one','two', 'three'], 1)
        >>> a['one']
        1
        >>> a['two']
        1
        """
        for key in keys:
            self[key] += count

    def argMax(self):
        """
        Returns the key with the highest value, or None when the
        counter is empty. On ties the key that comes first in
        iteration order wins.
        """
        if not self:
            return None
        # max() keeps the first maximal item, matching values.index(max(values))
        return max(self.items(), key=lambda item: item[1])[0]

    def sortedKeys(self):
        """
        Returns a list of keys sorted by their values. Keys
        with the highest values will appear first; keys with equal
        values keep their relative iteration order.

        >>> a = Counter()
        >>> a['first'] = -2
        >>> a['second'] = 4
        >>> a['third'] = 1
        >>> a.sortedKeys()
        ['second', 'third', 'first']
        """
        # key= replaces the Python 2 only cmp= argument; reverse=True
        # preserves the stable order of equal elements.
        ordered = sorted(self.items(), key=lambda item: item[1], reverse=True)
        return [key for key, _ in ordered]

    def totalCount(self):
        """
        Returns the sum of counts for all keys.
        """
        return sum(self.values())

    def normalize(self):
        """
        Edits the counter such that the total count of all
        keys sums to 1. The ratio of counts for all keys
        will remain the same. A counter whose total is 0
        (including an empty one) is left unchanged.
        """
        total = float(self.totalCount())
        if total == 0:
            return
        for key in list(self.keys()):
            self[key] = self[key] / total

    def divideAll(self, divisor):
        """
        Divides all counts by divisor (converted to float first).
        """
        divisor = float(divisor)
        for key in self:
            self[key] /= divisor

    def copy(self):
        """
        Returns a copy of the counter
        """
        return Counter(dict.copy(self))

    def __mul__(self, y):
        """
        Multiplying two counters gives the dot product of their vectors where
        each unique label is a vector element.

        >>> a = Counter()
        >>> b = Counter()
        >>> a['first'] = -2
        >>> a['second'] = 4
        >>> b['first'] = 3
        >>> b['second'] = 5
        >>> a['third'] = 1.5
        >>> a['fourth'] = 2.5
        >>> a * b
        14
        """
        # Walk the smaller operand so the work is bounded by the shorter vector
        smaller, larger = (y, self) if len(self) > len(y) else (self, y)
        return sum(smaller[key] * larger[key]
                   for key in smaller if key in larger)

    def __radd__(self, y):
        """
        Reflected addition: increments this counter in place by the
        values stored in y and returns this counter.

        Note that ``a += b`` on two counters is served by ``__add__``
        (no ``__iadd__`` is defined), so it rebinds ``a`` to a new counter:

        >>> a = Counter()
        >>> b = Counter()
        >>> a['first'] = -2
        >>> a['second'] = 4
        >>> b['first'] = 3
        >>> b['third'] = 1
        >>> a += b
        >>> a['first']
        1
        """
        for key, value in list(y.items()):
            self[key] += value
        # Return the updated counter so `other + counter` does not yield None
        return self

    def __add__(self, y):
        """
        Adding two counters gives a counter with the union of all keys and
        counts of the second added to counts of the first.

        >>> a = Counter()
        >>> b = Counter()
        >>> a['first'] = -2
        >>> a['second'] = 4
        >>> b['first'] = 3
        >>> b['third'] = 1
        >>> (a + b)['first']
        1
        """
        result = Counter()
        for key in self:
            result[key] = self[key] + y[key] if key in y else self[key]
        for key in y:
            if key not in self:
                result[key] = y[key]
        return result

    def __sub__(self, y):
        """
        Subtracting a counter from another gives a counter with the union of all keys and
        counts of the second subtracted from counts of the first.

        >>> a = Counter()
        >>> b = Counter()
        >>> a['first'] = -2
        >>> a['second'] = 4
        >>> b['first'] = 3
        >>> b['third'] = 1
        >>> (a - b)['first']
        -5
        """
        result = Counter()
        for key in self:
            result[key] = self[key] - y[key] if key in y else self[key]
        for key in y:
            if key not in self:
                result[key] = -1 * y[key]
        return result

0 comments on commit c74c1d4

Please sign in to comment.