From 25ef8ea66c1686946c94a71b2ed86a39d40981bb Mon Sep 17 00:00:00 2001 From: Kaya Celebi Date: Sun, 8 Sep 2024 14:05:49 -0400 Subject: [PATCH 1/8] Implemented episode stops for learning agent module --- sim | 76 ------------------------- src/machine_learning/learning_agents.py | 73 +++++++++++++++++------- 2 files changed, 52 insertions(+), 97 deletions(-) delete mode 100644 sim diff --git a/sim b/sim deleted file mode 100644 index 7af8dfd..0000000 --- a/sim +++ /dev/null @@ -1,76 +0,0 @@ -import numpy as np - -''' - Automata is state by state - Note: might be better to do this in C for speedup - - idea -- create reinforcement learning model - - we have state, action - - model learns to add points to the CA state to achieve some outcome - - consider: ask agent to create a certain state, or some qualification using k moves - - how does it do it - - - we need pattern detection - - we need cycle detection - - can we transform each image to a graph... - - each state to a graph? - - - some web server where you can pkay / analyze differnet automata setup - - have the interactive code etc - - need to port it all to JS -''' - - - -''' - Brute force update -- O(n^2) - - NOTE: this is *not* updating in-place -''' -def update(state, rule): - state = state_fix(state) - - x, y = state.shape - new_state = np.zeros(state.shape, dtype = int) - - for i in range(x): - for j in range(y): - neighbors = get_neighbors(i, j, shape = state.shape) - new_state[i, j] = rule(neighbors, cell = state[i, j], state = state) - - return new_state - -def np_update(state): - x, y = state.shape - new_state = np.zeros(state.shape) - - # write a lambda to generate the update more efficiently - return - -''' - Returns all valid neighbors -''' -def get_neighbors(i, j, shape): - # list comp generates all neighbors including center. If center or invalid neighbor, - # does i-10, j-10 as coord to remove in next step - neighbors = np.reshape([[[i + u, j + v] if in_range(i + u, j + v, shape = shape) else [i - 10, j - 10] for u in [-1, 0, 1]] for v in [-1, 0, 1]], (-1, 2)) - return neighbors[~np.all(np.logical_or(neighbors == [i, j], neighbors == [i - 10, j - 10]), axis = 1)] # make sure to exlude center and not in-range values - -''' - Check the provided coord is in range of the matrix -''' -def in_range(i, j, shape): - if i < shape[0] and i > -1 and j > -1 and j < shape[1]: - return True - return False - -def get_random_state(shape): - return np.random.randint(0, 2, size = shape) - -def state_fix(state): - if type(state) != np.array: - state = np.array(state, dtype = int) - if len(state.shape) == 1: - state = state.reshape((1, -1)) - - return state \ No newline at end of file diff --git a/src/machine_learning/learning_agents.py b/src/machine_learning/learning_agents.py index 09a50f8..5429b9e 100644 --- a/src/machine_learning/learning_agents.py +++ b/src/machine_learning/learning_agents.py @@ -20,28 +20,28 @@ def get_action(self, state): class ValueEstimationAgent(Agent): """ - Abstract agent which assigns values to (state,action) - Q-Values for an environment. As well as a value to a - state and a policy given respectively by, - - V(s) = max_{a in actions} Q(s,a) - policy(s) = arg_max_{a in actions} Q(s,a) - - Both ValueIterationAgent and QLearningAgent inherit - from this agent. While a ValueIterationAgent has - a model of the environment via a MarkovDecisionProcess - (see mdp.py) that is used to estimate Q-Values before - ever actually acting, the QLearningAgent estimates - Q-Values while acting in the environment. 
- """ + Abstract agent which assigns values to (state,action) + Q-Values for an environment. As well as a value to a + state and a policy given respectively by, + + V(s) = max_{a in actions} Q(s,a) + policy(s) = arg_max_{a in actions} Q(s,a) + + Both ValueIterationAgent and QLearningAgent inherit + from this agent. While a ValueIterationAgent has + a model of the environment via a MarkovDecisionProcess + (see mdp.py) that is used to estimate Q-Values before + ever actually acting, the QLearningAgent estimates + Q-Values while acting in the environment. + """ def __init__(self, alpha, epsilon, gamma, num_training): ''' - alpha - learning rate - epsilon - exploration rate - gamma - discount factor - num_training - number of training episodes, i.e. no learning after these many episodes - ''' + alpha - learning rate + epsilon - exploration rate + gamma - discount factor + num_training - number of training episodes, i.e. no learning after these many episodes + ''' self.alpha = alpha self.epsilon = epsilon self.gamma = gamma @@ -71,12 +71,43 @@ def observe_transition(self, state, action, next_state, delta_reward): self.episode_rewards += delta_reward self.update(state, action, next_state, delta_reward) + ''' + Called by env when new episode is starting + ''' def start_episode(self): - ... + self.last_state = None + self.last_action = None + self.episode_rewards = 0.0 def stop_episode(self): - ... + if self.is_in_training(): + self.accum_train_rewards += self.episode_rewards + else: + self.accum_test_rewards += self.episode_rewards + + self.episodes_so_far += 1 + + # set vars for testing + if self.is_in_testing: + self.epsilon = 0.0 # no exploration + self.alpha = 0.0 # no learning + + + def is_in_training(self): + return self.episodes_so_far < self.num_training + + def is_in_testing(self): + return not self.is_in_training + + + ''' + actionFn: Function which takes a state and returns the list of legal actions + alpha - learning rate + epsilon - exploration rate + gamma - discount factor + numTraining - number of training episodes, i.e. 
no learning after these many episodes + ''' def __init__(self, action_func = None, num_training = 100, epsilon = 0.5, alpha = 0.5, gamma = 1): # we should never be in this position, overwrite this later if action_func is None: From c74c1d431012f7df1c688cd6fff0864d3183f7bb Mon Sep 17 00:00:00 2001 From: Kaya Celebi Date: Sun, 8 Sep 2024 14:16:17 -0400 Subject: [PATCH 2/8] Implemented missing q learning func, added util.py --- src/machine_learning/learning_agents.py | 2 +- src/machine_learning/q_learning_agents.py | 29 ++- src/machine_learning/util.py | 216 ++++++++++++++++++++++ 3 files changed, 240 insertions(+), 7 deletions(-) create mode 100644 src/machine_learning/util.py diff --git a/src/machine_learning/learning_agents.py b/src/machine_learning/learning_agents.py index 5429b9e..9f64f20 100644 --- a/src/machine_learning/learning_agents.py +++ b/src/machine_learning/learning_agents.py @@ -101,7 +101,7 @@ def is_in_testing(self): ''' - actionFn: Function which takes a state and returns the list of legal actions + action_func: Function which takes a state and returns the list of legal actions alpha - learning rate epsilon - exploration rate diff --git a/src/machine_learning/q_learning_agents.py b/src/machine_learning/q_learning_agents.py index 480da05..19e63d2 100644 --- a/src/machine_learning/q_learning_agents.py +++ b/src/machine_learning/q_learning_agents.py @@ -4,6 +4,7 @@ sys.path.append('../') from machine_learning.learning_agents import * +import machine_learning.util as util from path_handler import PathHandler as PH import sim.automata as atm from sim.rules import Rules @@ -14,24 +15,25 @@ class QLearningAgent(ReinforcementAgent): def __init__(self, **args): ReinforcementAgent.__init__(self, **args) - self.q_values = util.Counter()? + # how are we going to implement this counter obj + self.q_values = util.Counter() def get_Q_value(self, state, action): - """ + ''' Returns Q(state,action) Should return 0.0 if we have never seen a state or the Q node value otherwise - """ + ''' return self.q_values[(state, action)] def compute_value_from_Q_values(self, state): - """ + ''' Returns max_action Q(state,action) where the max is over legal actions. Note that if there are no legal actions, which is the case at the terminal state, you should return a value of 0.0. - """ + ''' legal_actions = self.get_legal_actions(state) @@ -39,7 +41,7 @@ def compute_value_from_Q_values(self, state): if len(legal_actions) == 0: return 0.0 - return max([self.getQValue(state, a) for a in legal_actions]) + return max([self.get_Q_value(state, a) for a in legal_actions]) def compute_action_from_Q_values(self, state): @@ -68,3 +70,18 @@ def get_action(self, state): return random.choice(legal_actions) return self.compute_action_from_Q_values(state) + + # update our q values here + def update(self, state, action, next_state, reward): + # from value estmation parent.parent + NSQ = self.get_value(next_state) + + + self.q_values[(state, action)] = self.get_Q_value(state, action) + self.alpha * (reward + self.discount*NSQ - self.get_Q_value(state, action)) + + + + + + + diff --git a/src/machine_learning/util.py b/src/machine_learning/util.py new file mode 100644 index 0000000..e07d3d2 --- /dev/null +++ b/src/machine_learning/util.py @@ -0,0 +1,216 @@ +class Counter(dict): + """ + A counter keeps track of counts for a set of keys. + + The counter class is an extension of the standard python + dictionary type. 
It is specialized to have number values + (integers or floats), and includes a handful of additional + functions to ease the task of counting data. In particular, + all keys are defaulted to have value 0. Using a dictionary: + + a = {} + print a['test'] + + would give an error, while the Counter class analogue: + + >>> a = Counter() + >>> print a['test'] + 0 + + returns the default 0 value. Note that to reference a key + that you know is contained in the counter, + you can still use the dictionary syntax: + + >>> a = Counter() + >>> a['test'] = 2 + >>> print a['test'] + 2 + + This is very useful for counting things without initializing their counts, + see for example: + + >>> a['blah'] += 1 + >>> print a['blah'] + 1 + + The counter also includes additional functionality useful in implementing + the classifiers for this assignment. Two counters can be added, + subtracted or multiplied together. See below for details. They can + also be normalized and their total count and arg max can be extracted. + """ + + def __getitem__(self, idx): + self.setdefault(idx, 0) + return dict.__getitem__(self, idx) + + def incrementAll(self, keys, count): + """ + Increments all elements of keys by the same count. + + >>> a = Counter() + >>> a.incrementAll(['one','two', 'three'], 1) + >>> a['one'] + 1 + >>> a['two'] + 1 + """ + for key in keys: + self[key] += count + + def argMax(self): + """ + Returns the key with the highest value. + """ + if len(list(self.keys())) == 0: + return None + all = list(self.items()) + values = [x[1] for x in all] + maxIndex = values.index(max(values)) + return all[maxIndex][0] + + def sortedKeys(self): + """ + Returns a list of keys sorted by their values. Keys + with the highest values will appear first. + + >>> a = Counter() + >>> a['first'] = -2 + >>> a['second'] = 4 + >>> a['third'] = 1 + >>> a.sortedKeys() + ['second', 'third', 'first'] + """ + sortedItems = list(self.items()) + + def compare(x, y): return sign(y[1] - x[1]) + sortedItems.sort(cmp=compare) + return [x[0] for x in sortedItems] + + def totalCount(self): + """ + Returns the sum of counts for all keys. + """ + return sum(self.values()) + + def normalize(self): + """ + Edits the counter such that the total count of all + keys sums to 1. The ratio of counts for all keys + will remain the same. Note that normalizing an empty + Counter will result in an error. + """ + total = float(self.totalCount()) + if total == 0: + return + for key in list(self.keys()): + self[key] = self[key] / total + + def divideAll(self, divisor): + """ + Divides all counts by divisor + """ + divisor = float(divisor) + for key in self: + self[key] /= divisor + + def copy(self): + """ + Returns a copy of the counter + """ + return Counter(dict.copy(self)) + + def __mul__(self, y): + """ + Multiplying two counters gives the dot product of their vectors where + each unique label is a vector element. + + >>> a = Counter() + >>> b = Counter() + >>> a['first'] = -2 + >>> a['second'] = 4 + >>> b['first'] = 3 + >>> b['second'] = 5 + >>> a['third'] = 1.5 + >>> a['fourth'] = 2.5 + >>> a * b + 14 + """ + sum = 0 + x = self + if len(x) > len(y): + x, y = y, x + for key in x: + if key not in y: + continue + sum += x[key] * y[key] + return sum + + def __radd__(self, y): + """ + Adding another counter to a counter increments the current counter + by the values stored in the second counter. 
+ + >>> a = Counter() + >>> b = Counter() + >>> a['first'] = -2 + >>> a['second'] = 4 + >>> b['first'] = 3 + >>> b['third'] = 1 + >>> a += b + >>> a['first'] + 1 + """ + for key, value in list(y.items()): + self[key] += value + + def __add__(self, y): + """ + Adding two counters gives a counter with the union of all keys and + counts of the second added to counts of the first. + + >>> a = Counter() + >>> b = Counter() + >>> a['first'] = -2 + >>> a['second'] = 4 + >>> b['first'] = 3 + >>> b['third'] = 1 + >>> (a + b)['first'] + 1 + """ + addend = Counter() + for key in self: + if key in y: + addend[key] = self[key] + y[key] + else: + addend[key] = self[key] + for key in y: + if key in self: + continue + addend[key] = y[key] + return addend + + def __sub__(self, y): + """ + Subtracting a counter from another gives a counter with the union of all keys and + counts of the second subtracted from counts of the first. + + >>> a = Counter() + >>> b = Counter() + >>> a['first'] = -2 + >>> a['second'] = 4 + >>> b['first'] = 3 + >>> b['third'] = 1 + >>> (a - b)['first'] + -5 + """ + addend = Counter() + for key in self: + if key in y: + addend[key] = self[key] - y[key] + else: + addend[key] = self[key] + for key in y: + if key in self: + continue + addend[key] = -1 * y[key] + return addend From 1a622f472e05f6e5e38ac952a5ce6f3f5835863f Mon Sep 17 00:00:00 2001 From: Kaya Celebi Date: Sun, 8 Sep 2024 14:17:40 -0400 Subject: [PATCH 3/8] Adding missing get_policy and get_value to qlearning --- src/machine_learning/q_learning_agents.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/machine_learning/q_learning_agents.py b/src/machine_learning/q_learning_agents.py index 19e63d2..4d056f5 100644 --- a/src/machine_learning/q_learning_agents.py +++ b/src/machine_learning/q_learning_agents.py @@ -80,8 +80,10 @@ def update(self, state, action, next_state, reward): self.q_values[(state, action)] = self.get_Q_value(state, action) + self.alpha * (reward + self.discount*NSQ - self.get_Q_value(state, action)) + def get_policy(self, state): + return self.compute_action_from_Q_values(state) - - + def get_value(self, state): + return self.compute_value_from_Q_values(state) From aecb0fbdf59fe4e714576c1b89625f3deb917cf8 Mon Sep 17 00:00:00 2001 From: Kaya Celebi Date: Sun, 8 Sep 2024 14:23:08 -0400 Subject: [PATCH 4/8] Adding RL tests --- nb/q-learning-v0.ipynb | 125 ++++++++++++++++++++++++ src/machine_learning/learning_agents.py | 2 +- 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 nb/q-learning-v0.ipynb diff --git a/nb/q-learning-v0.ipynb b/nb/q-learning-v0.ipynb new file mode 100644 index 0000000..7e5b287 --- /dev/null +++ b/nb/q-learning-v0.ipynb @@ -0,0 +1,125 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "8fda40c4", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import sys\n", + "sys.path.append('../src')\n", + "import sim.automata as atm\n", + "import analysis.analysis as ans\n", + "import analysis.stats as stats\n", + "from sim.rules import Rules\n", + "\n", + "from machine_learning.q_learning_agents import QLearningAgent" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "373b2a7b", + "metadata": {}, + "outputs": [], + "source": [ + "# we need to define an action function to obtain the legal actions, which would involve placing\n", + "# anywhere that is not alive on the board.\n", + "\n", + "# then, how to 
evaluate said action\n", + "\n", + "agent = QLearningAgent(action_func = ...)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ada366a0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1, 1, 0, 1, 1, 0, 0, 0, 1, 1],\n", + " [1, 1, 0, 0, 1, 0, 1, 0, 1, 1],\n", + " [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],\n", + " [0, 0, 1, 1, 0, 1, 1, 1, 1, 1],\n", + " [0, 0, 0, 0, 0, 1, 1, 1, 0, 0],\n", + " [1, 0, 1, 1, 1, 1, 1, 1, 1, 0],\n", + " [0, 1, 0, 0, 0, 0, 1, 0, 0, 1],\n", + " [0, 1, 1, 0, 0, 1, 0, 0, 1, 1],\n", + " [0, 1, 1, 0, 0, 1, 1, 0, 1, 1],\n", + " [1, 1, 1, 0, 0, 0, 1, 1, 0, 0]])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "state = atm.get_random_state((10,10))\n", + "state" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4afa74d9", + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'numpy.ndarray' object has no attribute 'get_legal_actions'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/var/folders/cc/c6pxrx8n6wx77yf6_g8dvdw40000gn/T/ipykernel_14646/3427211211.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0magent\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/q_learning_agents.py\u001b[0m in \u001b[0;36mget_action\u001b[0;34m(self, state)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0mlegal_actions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_legal_actions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;31m# Terminal state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/learning_agents.py\u001b[0m in \u001b[0;36mget_legal_actions\u001b[0;34m(self, state)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_legal_actions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 69\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mobserve_transition\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdelta_reward\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/learning_agents.py\u001b[0m in \u001b[0;36m\u001b[0;34m(state)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;31m# we should never be in this position, overwrite this later\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0maction_func\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0maction_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_legal_actions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# not possible, state not an obj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0maction_func\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'numpy.ndarray' object has no attribute 'get_legal_actions'" + ] + } + ], + "source": [ + "agent.get_action(state)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d97459e4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/machine_learning/learning_agents.py b/src/machine_learning/learning_agents.py index 9f64f20..f2bb226 100644 --- a/src/machine_learning/learning_agents.py +++ b/src/machine_learning/learning_agents.py @@ -35,7 +35,7 @@ class ValueEstimationAgent(Agent): Q-Values while acting in the environment. 
""" - def __init__(self, alpha, epsilon, gamma, num_training): + def __init__(self, alpha = 1.0, epsilon = 0.05, gamma = 0.8, num_training = 10): ''' alpha - learning rate epsilon - exploration rate From 09f402b0467bf61f51f045353cc6d2ce68da0364 Mon Sep 17 00:00:00 2001 From: Kaya Celebi Date: Sun, 8 Sep 2024 21:33:30 -0400 Subject: [PATCH 5/8] Added missing reward func component --- src/machine_learning/learning_agents.py | 28 ++++++++++++++++++++--- src/machine_learning/q_learning_agents.py | 7 +++--- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/machine_learning/learning_agents.py b/src/machine_learning/learning_agents.py index f2bb226..f3ca6b5 100644 --- a/src/machine_learning/learning_agents.py +++ b/src/machine_learning/learning_agents.py @@ -64,13 +64,38 @@ class ReinforcementAgent(ValueEstimationAgent): def update(self, state, action, next_state, reward): raise NotImplementedError() + + ''' + Provided custom action_func -- to be overloaded into a new class + + Provides the possible legal actions for CA agent + ''' def get_legal_actions(self, state): return self.action_func(state) + ''' + Who calls this? + ''' def observe_transition(self, state, action, next_state, delta_reward): self.episode_rewards += delta_reward self.update(state, action, next_state, delta_reward) + + ''' + At each point in the game, we observe the state we have just arrived at + and assess how that affects our score. + ''' + def observe_function(self, state): + + if ... : + reward = ... #current_reward - prev state reward :: delta_reward + self.observe_transition( + state = self.last_state, + action = self.last_action, + next_state = state, + delta_reward = reward + ) + ''' Called by env when new episode is starting ''' @@ -109,9 +134,6 @@ def is_in_testing(self): numTraining - number of training episodes, i.e. 
no learning after these many episodes ''' def __init__(self, action_func = None, num_training = 100, epsilon = 0.5, alpha = 0.5, gamma = 1): - # we should never be in this position, overwrite this later - if action_func is None: - action_func = lambda state: state.get_legal_actions() # not possible, state not an obj self.action_func = action_func self.episodes_so_far = 0 diff --git a/src/machine_learning/q_learning_agents.py b/src/machine_learning/q_learning_agents.py index 4d056f5..bf97853 100644 --- a/src/machine_learning/q_learning_agents.py +++ b/src/machine_learning/q_learning_agents.py @@ -56,7 +56,10 @@ def compute_action_from_Q_values(self, state): if len(legal_actions) == 0: return None - best = sorted([(a, self.get_Q_value(state, a)) for a in legal_actions], key = lambda x: x[1], reverse = True) + best = sorted( + [(a, self.get_Q_value(state, a)) for a in legal_actions], + key = lambda x: x[1], reverse = True + ) return random.choice([b[0] for b in best if b[1] == best[0][1]]) def get_action(self, state): @@ -75,8 +78,6 @@ def get_action(self, state): def update(self, state, action, next_state, reward): # from value estmation parent.parent NSQ = self.get_value(next_state) - - self.q_values[(state, action)] = self.get_Q_value(state, action) + self.alpha * (reward + self.discount*NSQ - self.get_Q_value(state, action)) From be27bf075061a596e06b935006e58cdcfef90ed8 Mon Sep 17 00:00:00 2001 From: Kaya Celebi Date: Sun, 8 Sep 2024 23:27:26 -0400 Subject: [PATCH 6/8] Added state module, adding train.py to run epiusodes --- nb/q-learning-v0.ipynb | 105 +++++++++++++++--------- src/machine_learning/learning_agents.py | 28 +++++-- src/machine_learning/state.py | 32 ++++++++ src/machine_learning/train.py | 24 ++++++ src/sim/sim.py | 2 +- 5 files changed, 144 insertions(+), 47 deletions(-) create mode 100644 src/machine_learning/state.py create mode 100644 src/machine_learning/train.py diff --git a/nb/q-learning-v0.ipynb b/nb/q-learning-v0.ipynb index 7e5b287..6a4bcea 100644 --- a/nb/q-learning-v0.ipynb +++ b/nb/q-learning-v0.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "8fda40c4", + "id": "84dfa7f6", "metadata": {}, "outputs": [], "source": [ @@ -23,41 +23,26 @@ }, { "cell_type": "code", - "execution_count": 9, - "id": "373b2a7b", - "metadata": {}, - "outputs": [], - "source": [ - "# we need to define an action function to obtain the legal actions, which would involve placing\n", - "# anywhere that is not alive on the board.\n", - "\n", - "# then, how to evaluate said action\n", - "\n", - "agent = QLearningAgent(action_func = ...)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "ada366a0", + "execution_count": 2, + "id": "b7a26397", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([[1, 1, 0, 1, 1, 0, 0, 0, 1, 1],\n", - " [1, 1, 0, 0, 1, 0, 1, 0, 1, 1],\n", - " [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],\n", - " [0, 0, 1, 1, 0, 1, 1, 1, 1, 1],\n", - " [0, 0, 0, 0, 0, 1, 1, 1, 0, 0],\n", - " [1, 0, 1, 1, 1, 1, 1, 1, 1, 0],\n", - " [0, 1, 0, 0, 0, 0, 1, 0, 0, 1],\n", - " [0, 1, 1, 0, 0, 1, 0, 0, 1, 1],\n", - " [0, 1, 1, 0, 0, 1, 1, 0, 1, 1],\n", - " [1, 1, 1, 0, 0, 0, 1, 1, 0, 0]])" + "array([[0, 0, 0, 1, 1, 1, 0, 0, 1, 0],\n", + " [1, 1, 1, 1, 0, 1, 0, 1, 0, 0],\n", + " [0, 0, 0, 0, 1, 1, 0, 1, 0, 1],\n", + " [0, 0, 0, 1, 0, 0, 1, 1, 0, 1],\n", + " [1, 1, 0, 0, 0, 0, 0, 1, 1, 1],\n", + " [1, 0, 0, 1, 1, 0, 1, 0, 0, 1],\n", + " [1, 1, 1, 0, 0, 0, 1, 0, 1, 0],\n", + " [1, 0, 1, 0, 1, 0, 0, 0, 1, 0],\n", + " [1, 0, 1, 0, 1, 0, 0, 0, 0, 
1],\n", + " [0, 1, 1, 1, 0, 1, 1, 1, 1, 1]])" ] }, - "execution_count": 6, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -70,21 +55,40 @@ { "cell_type": "code", "execution_count": 8, - "id": "4afa74d9", + "id": "80a4c523", + "metadata": {}, + "outputs": [], + "source": [ + "# we need to define an action function to obtain the legal actions, which would involve placing\n", + "# anywhere that is not alive on the board.\n", + "\n", + "# then, how to evaluate said action\n", + "\n", + "action_func = lambda state: np.arange(state.shape[0]*state.shape[1])[state.flat == 0]\n", + "\n", + "agent = QLearningAgent(action_func = action_func)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "40145b18", "metadata": {}, "outputs": [ { - "ename": "AttributeError", - "evalue": "'numpy.ndarray' object has no attribute 'get_legal_actions'", + "ename": "TypeError", + "evalue": "unhashable type: 'numpy.ndarray'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/var/folders/cc/c6pxrx8n6wx77yf6_g8dvdw40000gn/T/ipykernel_14646/3427211211.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0magent\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/q_learning_agents.py\u001b[0m in \u001b[0;36mget_action\u001b[0;34m(self, state)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0mlegal_actions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_legal_actions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;31m# Terminal state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/learning_agents.py\u001b[0m in \u001b[0;36mget_legal_actions\u001b[0;34m(self, state)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_legal_actions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 69\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mobserve_transition\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdelta_reward\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/learning_agents.py\u001b[0m in \u001b[0;36m\u001b[0;34m(state)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;31m# we should never be in this position, overwrite this later\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0maction_func\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0maction_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_legal_actions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# not possible, state not an obj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0maction_func\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'numpy.ndarray' object has no attribute 'get_legal_actions'" + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/var/folders/cc/c6pxrx8n6wx77yf6_g8dvdw40000gn/T/ipykernel_15343/3427211211.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0magent\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/q_learning_agents.py\u001b[0m in \u001b[0;36mget_action\u001b[0;34m(self, state)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepsilon\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlegal_actions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_action_from_Q_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/q_learning_agents.py\u001b[0m in \u001b[0;36mcompute_action_from_Q_values\u001b[0;34m(self, state)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \t\tbest = sorted(\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_Q_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlegal_actions\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreverse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \t\t)\n", + "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/q_learning_agents.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \t\tbest = sorted(\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_Q_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlegal_actions\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreverse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \t\t)\n", + "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/q_learning_agents.py\u001b[0m in \u001b[0;36mget_Q_value\u001b[0;34m(self, state, action)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mQ\u001b[0m \u001b[0mnode\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0motherwise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \t\t'''\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mq_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_value_from_Q_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/util.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetdefault\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: unhashable type: 'numpy.ndarray'" ] } ], @@ -92,10 +96,33 @@ "agent.get_action(state)" ] }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2b0e6324", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0, 1, 2, 6, 7, 9, 14, 16, 18, 19, 20, 21, 22, 23, 26, 28, 30,\n", + " 31, 32, 34, 35, 38, 42, 43, 44, 45, 46, 51, 52, 55, 57, 58, 63, 64,\n", + " 65, 67, 69, 71, 73, 75, 76, 77, 79, 81, 83, 85, 86, 87, 88, 90, 94])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.get_legal_actions(state)" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "d97459e4", + "id": "72ceab71", "metadata": {}, "outputs": [], "source": [] diff --git a/src/machine_learning/learning_agents.py b/src/machine_learning/learning_agents.py index f3ca6b5..88e31a7 100644 --- a/src/machine_learning/learning_agents.py +++ b/src/machine_learning/learning_agents.py @@ -74,7 +74,8 @@ def get_legal_actions(self, state): return self.action_func(state) ''' - Who calls this? + Called by observe_func after we have actually transitioned ot next state + -- then we have to record it ''' def observe_transition(self, state, action, next_state, delta_reward): self.episode_rewards += delta_reward @@ -87,15 +88,19 @@ def observe_transition(self, state, action, next_state, delta_reward): ''' def observe_function(self, state): - if ... : - reward = ... #current_reward - prev state reward :: delta_reward + # ensure we don't call at first episode + if self.last_state is not None: + delta_reward = self.reward_func(state) - self.reward_func(self.last_state) + self.observe_transition( state = self.last_state, action = self.last_action, next_state = state, - delta_reward = reward + delta_reward = delta_reward ) + return state + ''' Called by env when new episode is starting ''' @@ -104,6 +109,9 @@ def start_episode(self): self.last_action = None self.episode_rewards = 0.0 + ''' + Called by env at the end of an episode + ''' def stop_episode(self): if self.is_in_training(): self.accum_train_rewards += self.episode_rewards @@ -112,7 +120,7 @@ def stop_episode(self): self.episodes_so_far += 1 - # set vars for testing + # stop the learning for testing stage if self.is_in_testing: self.epsilon = 0.0 # no exploration self.alpha = 0.0 # no learning @@ -133,17 +141,23 @@ def is_in_testing(self): gamma - discount factor numTraining - number of training episodes, i.e. 
no learning after these many episodes ''' - def __init__(self, action_func = None, num_training = 100, epsilon = 0.5, alpha = 0.5, gamma = 1): + def __init__(self, action_func = None, reward_func = None, num_training = 100, epsilon = 0.5, alpha = 0.5, gamma = 1): self.action_func = action_func + self.reward_func = reward_func self.episodes_so_far = 0 self.accum_train_rewards = 0.0 - self.accum_train_rewards = 0.0 + self.accum_test_rewards = 0.0 self.num_training = int(num_training) self.epsilon = float(epsilon) self.alpha = float(alpha) self.discount = float(gamma) + + def final(self, state): + delta_reward = get_score() - self.last_state.get_score() + #... finsih implementing later + diff --git a/src/machine_learning/state.py b/src/machine_learning/state.py new file mode 100644 index 0000000..2950b5c --- /dev/null +++ b/src/machine_learning/state.py @@ -0,0 +1,32 @@ +import numpy as np + +import sys +sys.path.append('../') + +from path_handler import PathHandler as PH +import sim.automata as atm +from sim.rules import Rules + +class State: + + def __init__(self, state, step_num, rule): + self.values = state + self.step_num = step_num + self.rule = rule + + # fix this dogshit hash function + def __hash__(self): + return hash(self.values.tostring()) + + def __repr__(self): + return self.values.__repr__() + + ''' + Generate successor states based on given action + ''' + def get_successors(self, action): + new_state = self.values.copy() + new_state[action // self.values.shape[1], action % self.values.shape[1]] = 1 + new_state = atm.update(new_state, rule = self.rule) + + return State(new_state, step_num = self.step_num + 1, rule = self.rule) diff --git a/src/machine_learning/train.py b/src/machine_learning/train.py new file mode 100644 index 0000000..9302cbf --- /dev/null +++ b/src/machine_learning/train.py @@ -0,0 +1,24 @@ +import sys +sys.path.append('../') + +import sim.automata as atm +from sim.rules import Rules + +from machine_learning.q_learning_agents import QLearningAgent +from machine_learning.state import State + +def run(init_state, rule = Rules.CONWAY, **args): + state = State(init_state, step_num = 0, rule = rule) + + agent = QLearningAgent(action_func = ..., reward_func = ..., **args) + + while ...: + + # observe the current state + state = agent.observe_function(state) + + action = agent.get_action(state) + + state = state.get_successors(action) + + print(f"Rewards: {agent.accum_train_rewards}") diff --git a/src/sim/sim.py b/src/sim/sim.py index a6525a9..7c8d627 100644 --- a/src/sim/sim.py +++ b/src/sim/sim.py @@ -15,7 +15,7 @@ def play(state, steps, rule = Rules.CONWAY, verbose = False, verbose_func = disp i = 1 states = np.zeros((steps, *state.shape), dtype = int) states[0, :, :] = state - while i < steps and not is_terminal_state(state): + while i < steps and not atm.is_terminal_state(state): if verbose: verbose_func(state) From de0405271a10d8da44747ebde730e6db00a66587 Mon Sep 17 00:00:00 2001 From: Kaya Celebi Date: Sun, 8 Sep 2024 23:57:07 -0400 Subject: [PATCH 7/8] Reinforcement learning tests passewd --- nb/q-learning-v0.ipynb | 632 ++++++++++++++++++++++-- src/analysis/analysis.py | 5 +- src/machine_learning/learning_agents.py | 4 + src/machine_learning/state.py | 21 +- src/machine_learning/train.py | 6 +- 5 files changed, 615 insertions(+), 53 deletions(-) diff --git a/nb/q-learning-v0.ipynb b/nb/q-learning-v0.ipynb index 6a4bcea..040eb0a 100644 --- a/nb/q-learning-v0.ipynb +++ b/nb/q-learning-v0.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 1, - 
"id": "84dfa7f6", + "id": "54987006", "metadata": {}, "outputs": [], "source": [ @@ -18,44 +18,40 @@ "import analysis.stats as stats\n", "from sim.rules import Rules\n", "\n", - "from machine_learning.q_learning_agents import QLearningAgent" + "from machine_learning.q_learning_agents import QLearningAgent\n", + "from machine_learning.state import State" ] }, { "cell_type": "code", - "execution_count": 2, - "id": "b7a26397", + "execution_count": 7, + "id": "f80ef339", "metadata": {}, "outputs": [ { "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAMW0lEQVR4nO2dMW4WTRKG26vNYEWykgUih9x3+k/glBtwgj0Bt+ACOIccgZA2hHg2+NeZu/tz1bxT/c73PJKj9kx19cwr2fOqqm62bWsAsD7/qN4AAFwGYgUwAbECmIBYAUxArAAmIFYAE/75nF++ubnp+jzv379vL168eHLtz58/w7WvX792Y2buG1lT3XeW5+vXr9vPnz+fXJudQe++o3vO1jNrb9686cb88eNH6FrV2WZiZt6h0X63bbvpLVz801rbej+fP3/eeszWVPeNrKnuO8vz48eP4TOI3HO2nlkbEb1WdbaZmCMy+906+uPPYAATECuACYgVwATECmDCs8R6d3fX/fj08PDQbm5unvx5eHhQ7X+IYj+zPEdrvbPbksUUvXuOntdsfba2Gpk8o+9J77rMuzA625vZi3Jzc/NXa+2v1lq7vb29+/Tp05O/9+vXr/b9+/cn196+fdtub2+fXPv9+3f79u1bN/67d+/ay5cvu9eO1nr3He1ndt9ZnpEzmN03cwa9tcy1mbPNvCeqPCO5zN7b6Ltwf3/fvnz5krdu7u7uup+jM5/He9c1kSU0sxdG962wNFa0qBR5rmajZN7baJ7/1xjWDYAziBXABMQKYAJiBTABsQKYYG3dZGyU3lrm2qj10Joml4xFtZqlobLiFGuXXNvL5f7+fp+qm9WsG0XVSObaqPWgyiVjUa1maaisuCpbbHRGG9YNgDeIFcAExApgAmIFMAGxApiwm1hVZUgqent93G9vLcOopGpEplzt6DKuGSuWqyliKtjNZ814diPPU+WzqkrkFP7jrEQuerYKzzhb8hjJpco3V+R5iM+q6sBXUa6WiRk5g9l+q7obRs4nW/IY3WsmT8XzxGcFuGIQK4AJiBXABMQKYMJhYu390zzrwKeKGSVjPYzyVHWOzJytwu6YUfHM9t7rJWcb2c9h1o2i1EjRKS+z3zN1VMzYHVErrsL+y5QCKt6hQ7obVpQaqcrV6KioGVrVLrA0omebibnSO0R3Q4ATgFgBTECsACYgVgATniXWTPVC9FP/6FrVJ/lRzAyqKg3FGVTFPLoKSLFXlR6eZd28evXq7sOHD0/+nqoDX8XApkzMiuqiio6KR1ejuMWM6mG3qpt2cGXDttUMbKqogKmwqJyqUdxiZuytDesGwBvECmACYgUwAbECmIBYAUxgMNWO115Tw7TVKmBUw7CijeHKrRsGU+ksDbeGadGYK9ooo/uqYo6u3bBuALxBrAAmIFYAExArgAmIFcCE5QdTqcrVenvdJlbWigO4RvuJ5tlavFwtE3O1s42iKN9cfjCVosuesixPkWdFKaDCm8zs91oGcFkPplJ14BuxmudZUQro1jky8w7tvZ/Me9LwWQH8QawAJiBWABMQK4AJu3U3zND7h3qbWAGZtWinwVlnush+ZuuqQVCK8xvtNfuenInR2fbYrbuhovTp8dq9bQtVeZPK0nAbwLXaYKoVrRv5YKom+FSdGdiksAEe9xRdi+xntn4mGwXrhsFUAKcHsQKYgFgBTECsACY8S6yjz/kzetfNbBTVwCtVnlEytpiiokllUanIxDz6/YpyWNXNajaKqqNixZAoRXWRW9WNKs9oh8xy6yZTdTNaawU2yiim25CoyH6yMaN5qqwbVZ6Rtdk7hHUDcHIQK4AJiBXABMQKYAJiBTBh+cFUFY3EVqv0ccvT6WxVtlg0l0OsmxUHNkXWMteqLCq3PCtiRs9WZYth3QBcMYgVwATECmACYgUwAbECmLBbd8PMIJ7edY/X9tZUZErHoiVVmTwrSuQi+1GW+10Dh3U3XM2XqyhXq+j6V1EipxjYpPLqT+uzNkFJ0IpeoFNJVVXpWGQ/l+SpiInPCgCHglgBTECsACYgVgATDhNrhb0QWbvEhoquRXEbTNW7bhPabao8R2QsqgjLD6Zys1EUg4wUFssl+60YNKawbiqGYZV3N2wHf5LfNj8bJXIGs5ir2QuZ51lh3ajyVMTEugE4AYgVwATECmACYgUw4TCx9v5pnn0er7BRVEOievvZti0VM5pnhlHMjKWhyCVjfVWcbY+rrbpRdRpcrQJmtbOtsMUU56fKcwnrJrKWubaq0+DoniPchkSN1o5+T7JVNyvliXUDcAIQK4AJiBXABMQKYAJiBTDhsMFUigoYVUynBlsVlSFKS2O1qpuj87y/v2/bttUOpoqstckncFVMpwZbFZUhSksjcn4r2kWZazesGwBvECuACYgVwATECmACYgUwYfnBVG6lYyNGpWOqIVGK/UbLymbPc1RGqHqHFOVz28QOHV07yvPU3Q1XG2RU4e1mfNao55kZ+hUtV1MNGju6/HCJErnedSsOT6roqLhiF8fo2WaeZ+S+2ZiKPCmRA7hiECuACYgVwATECmDCqcUa/SSv6qjYW5utzxjds4LM0K/IfWeDxmYozk9hF53aulltMFX0U3/G7jhTuVq0FFDV3VBhF+1WItfMrJvI2iW57L02W1dZVG7latHn6WQX/S1JrBsAaxArgAmIFcAExApgAmIFMGGJ7oYV1k00ZoVFpbA7Hve0tyWUfZ7RmKqKpqOri3arulF1N1R8AlfFrLCoVhzYpHqe0ZiZPKP7UeRJ1Q3ACUCsACYgVgATECuACYgVwITDqm4UA35WHEyl+tS/WnWRahiWokmbolLq1A3TRmvR+67YvCxyBqpclNVFqpiKs12tioqGaQAnB7ECmIBYAUxArAAmIFYAE54l1kwHviiZYViRtdn6jIoBUqM8MigGSGUGU0W7G2Y6DUafp2KY2G4lchm/KlrGpehCOFtfsRTQqbuhqiyvYuiXohTwkBK5ijKuFTsNjtZGVAzgquhu6F
Yip3qePfBZAU4AYgUwAbECmIBYAUx4llhVn7l7/1BvQhslM8goksvM0sjYUBV2UcR+uSSXCvY+90tsxdLBVKrSMYUllCnjynTnc8ozet+ZpVGRp2owlcIWO2Qwlap0bLUyrmvJM3rfmaVxpsFUirP9W5JYNwDWIFYAExArgAmIFcAExApgwvLWzWpDolQVMKrKEFV3w6O7/mViOnWOPKS7YUXXv4ohUaoKGLfuhpE11X3P1DmSqhuAE4BYAUxArAAmIFYAExArgAmHNUxbzUZRNUw7S8yKBnhVDdOiMU/bMK0tZqOMWK3qZsVKn9Fab6/ZZxa5pzLmCBqmAVwxiBXABMQKYAJiBTABsQKYsFt3wxm9L1yZLnEzKrobRvYzi1nV3fDobpWz7obR6zJEYy7d3TBTbhXtEqcofcrGVA2JUnjYivuezdtVlQKWdjdU+XKqIVGqmIo8VR62ajBVJubee73kXdh7P9n9bvisAN4gVgATECuACYgVwIRniXX0OToztKp3z9l9FTbAdoElFCWaZ4aRXZQZJlYRU2VRKWIqzna3EjlFKdvsvqruhhXlaqo8M6VjFV0cr6VErre2m3UzKpFTlLLN7lthabh1ceytXfJconlmYlbkGY0ZfZ6z/W5YNwDeIFYAExArgAmIFcAExApgwvKDqTI2StQucrNuFBVNqqobpyFRGbvoKgdTqapuVDGjeWasm9Ha6JlVVN04DYnK2EUMpgK4YhArgAmIFcAExApgAmIFMOEw62a1qhtVwzQne6GiAkZlF1U0TFPEXMK6Ga2NqKhGWa1JW1XVzdFnm7GLVHkeHRPrBuAEIFYAExArgAmIFcAExApgwm7dDTOdBqMDmzId5Hp73YSdBqO5zIg8k+wzUxHdT0Weo5gz5IOpRt0Nq7zAo71dVRnXaD3jwaqGRB1dOlaVZzTm0t0NV+zAN1obUVHGpSqRG+FUOlaVZzQm3Q0BrhjECmACYgUwAbECmPAssa5mL1SQsa9UeUaGQF0yJCoyTGxmhWRiVgymUhE5v91K5CrshRWHYTnZKCsOpjq6W2WVdVNaIlfx2X3FsjxFnivaYpH9ZGNGn+eK1k0PSuQATgBiBTABsQKYgFgBTECsACZYD6ZS2B2P+12po2JFnlVd/44e+pXp4pip3Fq6u2H0E7jqs3s0F7dqlKOfmdK6iexn23LDxCLXZWJi3QCcAMQKYAJiBTABsQKYgFgBTHiWddNae9da632X/3dr7b87r6nuS0xirhrz3bZt/3pypfeZ+Lk/rbUve6+p7ktMYjrG5M9gABMQK4AJe4r1P4I11X2JSUy7mNMPTACwBvwZDGACYgUwAbECmIBYAUxArAAm/A8vcPUave7SMwAAAABJRU5ErkJggg==\n", "text/plain": [ - "array([[0, 0, 0, 1, 1, 1, 0, 0, 1, 0],\n", - " [1, 1, 1, 1, 0, 1, 0, 1, 0, 0],\n", - " [0, 0, 0, 0, 1, 1, 0, 1, 0, 1],\n", - " [0, 0, 0, 1, 0, 0, 1, 1, 0, 1],\n", - " [1, 1, 0, 0, 0, 0, 0, 1, 1, 1],\n", - " [1, 0, 0, 1, 1, 0, 1, 0, 0, 1],\n", - " [1, 1, 1, 0, 0, 0, 1, 0, 1, 0],\n", - " [1, 0, 1, 0, 1, 0, 0, 0, 1, 0],\n", - " [1, 0, 1, 0, 1, 0, 0, 0, 0, 1],\n", - " [0, 1, 1, 1, 0, 1, 1, 1, 1, 1]])" + "
" ] }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" } ], "source": [ - "state = atm.get_random_state((10,10))\n", - "state" + "np.random.seed(0)\n", + "init_state = State(atm.get_random_state((30, 30)), Rules.CONWAY)\n", + "#print(init_state)\n", + "ans.plot_state(init_state.values)" ] }, { "cell_type": "code", "execution_count": 8, - "id": "80a4c523", + "id": "30a5df76", "metadata": {}, "outputs": [], "source": [ @@ -65,64 +61,608 @@ "# then, how to evaluate said action\n", "\n", "action_func = lambda state: np.arange(state.shape[0]*state.shape[1])[state.flat == 0]\n", - "\n", - "agent = QLearningAgent(action_func = action_func)" + "reward_func = lambda state: (state.values == 1).sum()" ] }, { "cell_type": "code", - "execution_count": 11, - "id": "40145b18", + "execution_count": null, + "id": "2f56a560", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "69b95a7e", + "metadata": {}, + "outputs": [], + "source": [ + "def run(init_state, action_func, reward_func, episode_length = 10):\n", + " agent = QLearningAgent(action_func = action_func, reward_func = reward_func)\n", + " \n", + " state = init_state.copy()\n", + " while agent.episodes_so_far < agent.num_training + 50:\n", + " \n", + " agent.start_episode()\n", + " for i in range(episode_length):\n", + " _ = agent.observe_function(state)\n", + " \n", + " action = agent.get_action(state)\n", + " agent.do_action(state, action)\n", + " \n", + " state = state.get_successor(action)\n", + " \n", + " agent.stop_episode()\n", + " print(f\"Train Rewards: {agent.accum_train_rewards}\")\n", + " print(f\"Test Rewards: {agent.accum_test_rewards}\\n\")\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a4147c0f", "metadata": {}, "outputs": [ { - "ename": "TypeError", - "evalue": "unhashable type: 'numpy.ndarray'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/var/folders/cc/c6pxrx8n6wx77yf6_g8dvdw40000gn/T/ipykernel_15343/3427211211.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0magent\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/q_learning_agents.py\u001b[0m in \u001b[0;36mget_action\u001b[0;34m(self, state)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepsilon\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlegal_actions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_action_from_Q_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/q_learning_agents.py\u001b[0m in \u001b[0;36mcompute_action_from_Q_values\u001b[0;34m(self, state)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \t\tbest = sorted(\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_Q_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlegal_actions\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreverse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \t\t)\n", - "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/q_learning_agents.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \t\tbest = sorted(\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_Q_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlegal_actions\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreverse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \t\t)\n", - "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/q_learning_agents.py\u001b[0m in \u001b[0;36mget_Q_value\u001b[0;34m(self, state, action)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mQ\u001b[0m \u001b[0mnode\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0motherwise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \t\t'''\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mq_values\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_value_from_Q_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/research/cellular-automata/src/machine_learning/util.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetdefault\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mTypeError\u001b[0m: unhashable type: 'numpy.ndarray'" + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Rewards: 0.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -284.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -328.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -328.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -308.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -315.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -330.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -328.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -298.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -296.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -322.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -338.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -324.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -366.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -355.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -328.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -323.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -300.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -326.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -302.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -290.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -277.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -253.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -242.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -236.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -301.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -304.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -300.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -280.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -284.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -318.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -302.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -287.0\n", + 
"Test Rewards: 0.0\n", + "\n", + "Train Rewards: -260.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -285.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -260.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -222.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -228.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -257.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -264.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -255.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -233.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -256.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -248.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -264.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -276.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -282.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -263.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -229.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -251.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -283.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -265.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -276.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -281.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -264.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -222.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -232.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -239.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -259.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -261.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -278.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -268.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -253.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -257.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -262.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -271.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -255.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -230.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -235.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -256.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -232.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -237.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -228.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -243.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -240.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -240.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -240.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -240.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -240.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -238.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -236.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -241.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -231.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -227.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -205.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -202.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -199.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -195.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -195.0\n", + "Test Rewards: 
0.0\n", + "\n", + "Train Rewards: -198.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -203.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -216.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -213.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -223.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -222.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -223.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -223.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -223.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -223.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 0.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 1.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 33.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 27.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 37.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 48.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 29.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 41.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 41.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 52.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 12.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 38.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 20.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 15.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 4.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 4.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 4.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 4.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 4.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 4.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 4.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 8.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 8.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 8.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 8.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 8.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 25.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 6.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 25.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 20.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 30.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 8.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 11.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 23.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 16.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 18.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 23.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 25.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: 19.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: -4.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: -22.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: -22.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test 
Rewards: -22.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: -22.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: -22.0\n", + "\n", + "Train Rewards: -225.0\n", + "Test Rewards: -22.0\n", + "\n" ] } ], "source": [ - "agent.get_action(state)" + "run(init_state, action_func, reward_func)" ] }, { "cell_type": "code", - "execution_count": 10, - "id": "2b0e6324", + "execution_count": 25, + "id": "1e634653", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1, 0, 0, 1],\n", + " [0, 0, 0, 1],\n", + " [0, 0, 0, 1],\n", + " [1, 1, 1, 0]])" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.observe_function(state)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "1f686987", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([ 0, 1, 2, 6, 7, 9, 14, 16, 18, 19, 20, 21, 22, 23, 26, 28, 30,\n", - " 31, 32, 34, 35, 38, 42, 43, 44, 45, 46, 51, 52, 55, 57, 58, 63, 64,\n", - " 65, 67, 69, 71, 73, 75, 76, 77, 79, 81, 83, 85, 86, 87, 88, 90, 94])" + "0" ] }, - "execution_count": 10, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "agent.get_legal_actions(state)" + "action = agent.get_action(state)\n", + "action" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "d0093739", + "metadata": {}, + "outputs": [], + "source": [ + "agent.last_state" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7948263e", + "metadata": {}, + "outputs": [], + "source": [ + "agent.do_action(state, action)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "d9b8a338", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAEvUlEQVR4nO3XMWobaxSG4fNf0kXgxhACLtwp/dwNZDNagVrvQNnAXYF3oSxAWcDtUoiEgEv35xY3pZ3YIGnyDc8DfyFminOQXjQzuruAP99fcw8AvIxYIYRYIYRYIYRYIYRYIcSb19x8fX3dt7e3ZxplPt++favv37/PPcbJvX//fpF7VVV9+PCh3r59O/cYJ/f169d6eHgYT17s7hefaZp6iXa7XVfV4s5S96qq3u/3c/9szuJnY0/25zEYQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQrwq1i9fvtQYY3EHEozu/vUNY2yqalNVdXV1Nd3d3V1irou6ubmp4/E49xgnt9S9qqrW63WtVqu5xzi57XZbh8Ph6X+Q7n7xqape4tntdrPPYK/Xnf1+30s0TVP3M/15Z4UQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQo7t/fcMYm6raVFVdXV1Nd3d3l5jrom5ubup4PM49xsktda+qqvV6XavVau4xTm673dbhcBhPXuzuF5+q6iWe3W43+wz2et3Z7/e9RNM0dT/Tn8dgCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCPHmNTdP01SHw+Fcs8zm8+fP1d1zj3FyS92rqurTp0/18ePHuce4qPG7L3OMsamqTVXVu3fvpvv7+0vMdVGPj4+1Wq3mHuPklrpXVdWPHz/qeDzOPcbJbbfb6u7x5MXufvGZpqmXaL/fzz3CWSx1r+7u3W7XVbXI08/0550VQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQrz53Q1jjE1VbX5+fBxj/HvekWZxXVUPcw9xBkvdq2q5u62fuzC6+5KD/JHGGIfu/nvuOU5tqXtVLXe3X+3lMRhCiBVCiPV//8w9wJksda+q5e727F7eWSGEf1YIIVYIIVYIIVYIIVYI8
R9a6B1zOUX4ggAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "state = state.get_successor(action)\n", + "ans.plot_state(state.values)" ] }, { "cell_type": "code", "execution_count": null, - "id": "72ceab71", + "id": "63dd58d7", "metadata": {}, "outputs": [], "source": [] diff --git a/src/analysis/analysis.py b/src/analysis/analysis.py index 45a2237..d058b81 100644 --- a/src/analysis/analysis.py +++ b/src/analysis/analysis.py @@ -1,4 +1,5 @@ import numpy as np +import matplotlib.pyplot as plt import sys sys.path.append('../') @@ -35,9 +36,9 @@ def plot_state(state): fig, axs = plt.subplots() if state.shape[0] > 1: - axs.imshow(~state[0], cmap = 'gray') + axs.imshow(state, cmap = 'gray') else: - axs.imshow(~state, cmap = 'gray') + axs.imshow(state[0], cmap = 'gray') axs.set_xticks(np.arange(len(state))+0.5) axs.set_yticks(np.arange(len(state))+0.5) diff --git a/src/machine_learning/learning_agents.py b/src/machine_learning/learning_agents.py index 88e31a7..74a2fd1 100644 --- a/src/machine_learning/learning_agents.py +++ b/src/machine_learning/learning_agents.py @@ -101,6 +101,10 @@ def observe_function(self, state): return state + def do_action(self, state, action): + self.last_state = state + self.last_action = action + ''' Called by env when new episode is starting ''' diff --git a/src/machine_learning/state.py b/src/machine_learning/state.py index 2950b5c..0134122 100644 --- a/src/machine_learning/state.py +++ b/src/machine_learning/state.py @@ -9,11 +9,12 @@ class State: - def __init__(self, state, step_num, rule): + def __init__(self, state, rule): self.values = state - self.step_num = step_num self.rule = rule + self._shape = self.values.shape + # fix this dogshit hash function def __hash__(self): return hash(self.values.tostring()) @@ -21,12 +22,24 @@ def __hash__(self): def __repr__(self): return self.values.__repr__() + def copy(self): + return State(self.values, self.rule) + + @property + def shape(self): + return self.values.shape + + @property + def flat(self): + return self.values.flat + + ''' Generate successor states based on given action ''' - def get_successors(self, action): + def get_successor(self, action): new_state = self.values.copy() new_state[action // self.values.shape[1], action % self.values.shape[1]] = 1 new_state = atm.update(new_state, rule = self.rule) - return State(new_state, step_num = self.step_num + 1, rule = self.rule) + return State(new_state, rule = self.rule) diff --git a/src/machine_learning/train.py b/src/machine_learning/train.py index 9302cbf..ad0ec9a 100644 --- a/src/machine_learning/train.py +++ b/src/machine_learning/train.py @@ -19,6 +19,10 @@ def run(init_state, rule = Rules.CONWAY, **args): action = agent.get_action(state) - state = state.get_successors(action) + agent.do_action(state, action) + + state = state.get_successor(action) + + print(f"Rewards: {agent.accum_train_rewards}") From 456fe7154a0b933e050a306eccfaf288f27bb9fe Mon Sep 17 00:00:00 2001 From: Kaya Celebi Date: Mon, 9 Sep 2024 00:24:27 -0400 Subject: [PATCH 8/8] Reinforcement Learning modules all collected --- nb/q-learning-v0.ipynb | 599 ++++------------------------------ requirements.txt | 3 +- src/machine_learning/train.py | 42 +-- 3 files changed, 86 insertions(+), 558 deletions(-) diff --git a/nb/q-learning-v0.ipynb b/nb/q-learning-v0.ipynb index 040eb0a..624bf1b 100644 --- a/nb/q-learning-v0.ipynb +++ b/nb/q-learning-v0.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "54987006", + "id": 
"dfa06e9c", "metadata": {}, "outputs": [], "source": [ @@ -24,13 +24,13 @@ }, { "cell_type": "code", - "execution_count": 7, - "id": "f80ef339", + "execution_count": 60, + "id": "819e4f6a", "metadata": {}, "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAMW0lEQVR4nO2dMW4WTRKG26vNYEWykgUih9x3+k/glBtwgj0Bt+ACOIccgZA2hHg2+NeZu/tz1bxT/c73PJKj9kx19cwr2fOqqm62bWsAsD7/qN4AAFwGYgUwAbECmIBYAUxArAAmIFYAE/75nF++ubnp+jzv379vL168eHLtz58/w7WvX792Y2buG1lT3XeW5+vXr9vPnz+fXJudQe++o3vO1jNrb9686cb88eNH6FrV2WZiZt6h0X63bbvpLVz801rbej+fP3/eeszWVPeNrKnuO8vz48eP4TOI3HO2nlkbEb1WdbaZmCMy+906+uPPYAATECuACYgVwATECmDCs8R6d3fX/fj08PDQbm5unvx5eHhQ7X+IYj+zPEdrvbPbksUUvXuOntdsfba2Gpk8o+9J77rMuzA625vZi3Jzc/NXa+2v1lq7vb29+/Tp05O/9+vXr/b9+/cn196+fdtub2+fXPv9+3f79u1bN/67d+/ay5cvu9eO1nr3He1ndt9ZnpEzmN03cwa9tcy1mbPNvCeqPCO5zN7b6Ltwf3/fvnz5krdu7u7uup+jM5/He9c1kSU0sxdG962wNFa0qBR5rmajZN7baJ7/1xjWDYAziBXABMQKYAJiBTABsQKYYG3dZGyU3lrm2qj10Joml4xFtZqlobLiFGuXXNvL5f7+fp+qm9WsG0XVSObaqPWgyiVjUa1maaisuCpbbHRGG9YNgDeIFcAExApgAmIFMAGxApiwm1hVZUgqent93G9vLcOopGpEplzt6DKuGSuWqyliKtjNZ814diPPU+WzqkrkFP7jrEQuerYKzzhb8hjJpco3V+R5iM+q6sBXUa6WiRk5g9l+q7obRs4nW/IY3WsmT8XzxGcFuGIQK4AJiBXABMQKYMJhYu390zzrwKeKGSVjPYzyVHWOzJytwu6YUfHM9t7rJWcb2c9h1o2i1EjRKS+z3zN1VMzYHVErrsL+y5QCKt6hQ7obVpQaqcrV6KioGVrVLrA0omebibnSO0R3Q4ATgFgBTECsACYgVgATniXWTPVC9FP/6FrVJ/lRzAyqKg3FGVTFPLoKSLFXlR6eZd28evXq7sOHD0/+nqoDX8XApkzMiuqiio6KR1ejuMWM6mG3qpt2cGXDttUMbKqogKmwqJyqUdxiZuytDesGwBvECmACYgUwAbECmIBYAUxgMNWO115Tw7TVKmBUw7CijeHKrRsGU+ksDbeGadGYK9ooo/uqYo6u3bBuALxBrAAmIFYAExArgAmIFcCE5QdTqcrVenvdJlbWigO4RvuJ5tlavFwtE3O1s42iKN9cfjCVosuesixPkWdFKaDCm8zs91oGcFkPplJ14BuxmudZUQro1jky8w7tvZ/Me9LwWQH8QawAJiBWABMQK4AJu3U3zND7h3qbWAGZtWinwVlnush+ZuuqQVCK8xvtNfuenInR2fbYrbuhovTp8dq9bQtVeZPK0nAbwLXaYKoVrRv5YKom+FSdGdiksAEe9xRdi+xntn4mGwXrhsFUAKcHsQKYgFgBTECsACY8S6yjz/kzetfNbBTVwCtVnlEytpiiokllUanIxDz6/YpyWNXNajaKqqNixZAoRXWRW9WNKs9oh8xy6yZTdTNaawU2yiim25CoyH6yMaN5qqwbVZ6Rtdk7hHUDcHIQK4AJiBXABMQKYAJiBTBh+cFUFY3EVqv0ccvT6WxVtlg0l0OsmxUHNkXWMteqLCq3PCtiRs9WZYth3QBcMYgVwATECmACYgUwAbECmLBbd8PMIJ7edY/X9tZUZErHoiVVmTwrSuQi+1GW+10Dh3U3XM2XqyhXq+j6V1EipxjYpPLqT+uzNkFJ0IpeoFNJVVXpWGQ/l+SpiInPCgCHglgBTECsACYgVgATDhNrhb0QWbvEhoquRXEbTNW7bhPabao8R2QsqgjLD6Zys1EUg4wUFssl+60YNKawbiqGYZV3N2wHf5LfNj8bJXIGs5ir2QuZ51lh3ajyVMTEugE4AYgVwATECmACYgUw4TCx9v5pnn0er7BRVEOievvZti0VM5pnhlHMjKWhyCVjfVWcbY+rrbpRdRpcrQJmtbOtsMUU56fKcwnrJrKWubaq0+DoniPchkSN1o5+T7JVNyvliXUDcAIQK4AJiBXABMQKYAJiBTDhsMFUigoYVUynBlsVlSFKS2O1qpuj87y/v2/bttUOpoqstckncFVMpwZbFZUhSksjcn4r2kWZazesGwBvECuACYgVwATECmACYgUwYfnBVG6lYyNGpWOqIVGK/UbLymbPc1RGqHqHFOVz28QOHV07yvPU3Q1XG2RU4e1mfNao55kZ+hUtV1MNGju6/HCJErnedSsOT6roqLhiF8fo2WaeZ+S+2ZiKPCmRA7hiECuACYgVwATECmDCqcUa/SSv6qjYW5utzxjds4LM0K/IfWeDxmYozk9hF53aulltMFX0U3/G7jhTuVq0FFDV3VBhF+1WItfMrJvI2iW57L02W1dZVG7latHn6WQX/S1JrBsAaxArgAmIFcAExApgAmIFMGGJ7oYV1k00ZoVFpbA7Hve0tyWUfZ7RmKqKpqOri3arulF1N1R8AlfFrLCoVhzYpHqe0ZiZPKP7UeRJ1Q3ACUCsACYgVgATECuACYgVwITDqm4UA35WHEyl+tS/WnWRahiWokmbolLq1A3TRmvR+67YvCxyBqpclNVFqpiKs12tioqGaQAnB7ECmIBYAUxArAAmIFYAE54l1kwHviiZYViRtdn6jIoBUqM8MigGSGUGU0W7G2Y6DUafp2KY2G4lchm/KlrGpehCOFtfsRTQqbuhqiyvYuiXohTwkBK5ijKuFTsNjtZGVAzgquhu6FYip3qePfBZAU4AYgUwAbECmIBYAUx4llhVn7l7/1BvQhslM8goksvM0sjYUBV2UcR+uSSXCvY+90tsxdLBVKrSMYUllCnjynTnc8ozet+ZpVGRp2owlcIWO2Qwlap0bLUyrmvJM3rfmaVxpsFUirP9W5JYNwDWIFYAExArgAmIFcAExApgwvLWzWpDolQVMKrKEFV3w6O7/mViOnWOPKS7YUXXv4ohUaoKGLfuhpE11X3P1DmSqhuAE4BYAUxArAAmIFYAExArgAmHNUxbzUZRNUw7
S8yKBnhVDdOiMU/bMK0tZqOMWK3qZsVKn9Fab6/ZZxa5pzLmCBqmAVwxiBXABMQKYAJiBTABsQKYsFt3wxm9L1yZLnEzKrobRvYzi1nV3fDobpWz7obR6zJEYy7d3TBTbhXtEqcofcrGVA2JUnjYivuezdtVlQKWdjdU+XKqIVGqmIo8VR62ajBVJubee73kXdh7P9n9bvisAN4gVgATECuACYgVwIRniXX0OToztKp3z9l9FTbAdoElFCWaZ4aRXZQZJlYRU2VRKWIqzna3EjlFKdvsvqruhhXlaqo8M6VjFV0cr6VErre2m3UzKpFTlLLN7lthabh1ceytXfJconlmYlbkGY0ZfZ6z/W5YNwDeIFYAExArgAmIFcAExApgwvKDqTI2StQucrNuFBVNqqobpyFRGbvoKgdTqapuVDGjeWasm9Ha6JlVVN04DYnK2EUMpgK4YhArgAmIFcAExApgAmIFMOEw62a1qhtVwzQne6GiAkZlF1U0TFPEXMK6Ga2NqKhGWa1JW1XVzdFnm7GLVHkeHRPrBuAEIFYAExArgAmIFcAExApgwm7dDTOdBqMDmzId5Hp73YSdBqO5zIg8k+wzUxHdT0Weo5gz5IOpRt0Nq7zAo71dVRnXaD3jwaqGRB1dOlaVZzTm0t0NV+zAN1obUVHGpSqRG+FUOlaVZzQm3Q0BrhjECmACYgUwAbECmPAssa5mL1SQsa9UeUaGQF0yJCoyTGxmhWRiVgymUhE5v91K5CrshRWHYTnZKCsOpjq6W2WVdVNaIlfx2X3FsjxFnivaYpH9ZGNGn+eK1k0PSuQATgBiBTABsQKYgFgBTECsACZYD6ZS2B2P+12po2JFnlVd/44e+pXp4pip3Fq6u2H0E7jqs3s0F7dqlKOfmdK6iexn23LDxCLXZWJi3QCcAMQKYAJiBTABsQKYgFgBTHiWddNae9da632X/3dr7b87r6nuS0xirhrz3bZt/3pypfeZ+Lk/rbUve6+p7ktMYjrG5M9gABMQK4AJe4r1P4I11X2JSUy7mNMPTACwBvwZDGACYgUwAbECmIBYAUxArAAm/A8vcPUave7SMwAAAABJRU5ErkJggg==\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWkUlEQVR4nO1dS44lRRKMGs0ORmxGaoHYN/u6EyfoLTfgBHMCbsEFuvewY4FASL2EdbLoofUqiJfuYWkWn4eZVIt8lR7uEZlR9czDwuPpOI5iGMb6+NfsAAzDyMGT1TA2gSerYWwCT1bD2ASerIaxCTxZDWMT/Lvn5qenpxfrPF999VX55JNPPl7/8ccfL65LKeWXX34pv/76a5dN/Vl0nb2njuXzzz//2/UXX3zRbXN7nbkn6+f2HlZ/kLGNYsvG/8MPP7ywid4Flk00Lqgf5L1txXYby08//VTev3//VFo4jiP9U0o5bn++//774xb19XEcx7ffftttE92D2LRiaV0jNvW4sPwo+oOMLSv+3veHZaOKrQYa2y2en5+P487889dgw9gE8sn6/Pxc/3f+G56enl78RPe8e/cuZdMb2/Pz89/azcSPIPLz/Pwc2ihieffuHS3+ut36P0UGiE2NzNjWqN+x1rjUY5CxqfvTGqd7eIoG4Onp6etSytellPLZZ589f/PNNx9/9/r16/Lpp59+vP79999fXLc+a13/+OOPL2xa7d7e8+WXX5aff/45tKlj+e23317YRX5a99RttGKpP4uuM/G3Ysv4efXq1QubqN2WTabPmWeWiWXU+4SM7W38qE2rP7ftvHnzphzHMYezsnhJOeEg2Vgy/Dlql8VZe+PPcrH6OvM8Ihs0fiSWs+usjWpsGTat/tTtHOashrE3PFkNYxN0TdaaqLcIdZQEyZDw1j23fus47sUSAYmlhTqWaNwyCY6Mn1a79XUmKRXFhsYfxcJK4iAJMhUY43QPXQmmV69ePX/33Xcff4ckHtAETUTUW+32JkrQZFHkB0lkzUy2ZPyslMQZlbzLxHb2nmbG9s2bN+Xt27fXE0z/X7D9CCTxgCZoIqK+ssCBlQSpoUq2ZPwg7SLPGXkeM2O7OrYWRRjGA6Brsmb4G4L6LwgiEED4WsYPixciQLg9w0+GJ9bPDIklI6RQIfKDxqYQUny0PS6IIhC+kOGWCiFFKz4Wl1T5QcQKEa9CBCcITx8ZC0OwwbBhcHCZKILFWaPv8QwhRSu+UVyMEb9SPB89w5FCfiSW3ncw40cxTpl35cOUNGc1jK1xaZ21hfqvAYvnXl2bPRp8E13z7eU7LM6X4ek1ED/RuGW5vWI9vYUotoyfDKLY1Bycvs6KrG0iYnOEV9XtjhLlq9aAkTVT1vPoXYsdtZ7OWjNVrEdnxoDGWTPrrDVUYvMSfPfPtDtKlK9aA0bWTBUb1u/dc/WZIevprDVTxXp0ZgyKOath7A9PVsPYBJcmKyLwZonCZyESSRyNJA5LfBH5yWysUFVNUAERnNTjxNpIgYhHmLFdSjAhCQ1VdYlRCSZWRQqkz0iyiFHRARWyjBARZKsxjKhWwqjcIRPyIwkNlKjX16UzidBqB0kwzexz5EdV0WFWUhBNFmVii/yoBBtRbBbyG8YD4NJkVW34zqD+q5P57o9sJEeqG6oqIjKACARYm9qRZ5axQbhkJFLJoBUbAomQf6Yo4irnWymWmZyVxflGbZ5XiS8UVSCjWDI2DyGKOLvO3rNKLDM5K4vz1e3OHCekz8hzz9ggfup4D3NWw9gbnqyGsQmW4KyzeNWo+NF11quV8pUb+xkCe1UhMyS2mXz6NpahBdNqrMyrRsXP4mJIbBk/veOW5c9X22Wts6rWZjM2mXflFl5nNYwHgCerYWyCS9UNWRUF63/3LLF53W4mFkb8KiBCBEQgUIO1QaAeJ0SsoKh8qawcGY3B1CMfkUp/KoEAI7kyU8g/Sgh/tVpGpo8jRSqjKkcykoQyUURJJAjqe2aK2pF2GfGrEkxRbKgQHvGTeRd6x2lkgizyo0hyWhRhGP8QyCcr4+S5UuKNv5ENytdqqLgMUukP4WLIJnd08zZjnDLPsEYdfwtIfzKxKSHnrMgpcoqNv5l7Rp4ip1q4RyrNX+XT99pliDpUghmFyCO6Z3nOmhEi9LaL8ioFX8vwHZaooAYyTorF/uwzy/QZeR5I/JENmg/oHQNzVsN4QHiyGsYmoBdMQ9b1FOuhLOE7oygZumFglhAeOaFvVoEx1Zopq/J/LweXbT5Hhdf1det7PMNPDQb/RDg36ifqs2KcED6auWfmpnCV+B/pT2RTzFkNY394shrGJlhismYEDgqoKuWz/NSo/aiE8Bk/DCg2bGQrL9ZAhDkZZPqXHYMlEkzIwj0jwaTaVLBSdUNGtUaVkEUlsGck1VgCml4hxdCK/I
VAukclmFBBwCo2o6o1qoQsqmRR77uhShYhQgpXijCMB8ClzecoRzq7PkAulgFSKV+xWZslameBERuyEQHhn6zNCyoBfu3n7Pf37rmHS0L+WRyJedLZqFhGCUFGiQoYAnuWEITBjVmbPs76nMlV0DhrAXjiKC62eiyM2DLjj/i52p+Mb5bwQMWNkbFt2URjHT0Pc1bDeAB4shrGJpCffF5/hvKfq5wvG8uqVe9nVpofWTCAwVlXFvJPO0UO+U6u2vDNiqXGzKJeUX9mbsTOxI8857M2sn4Y46Ram63RaucwZzWMveHJahib4NJkZS321//u0Wp6vWBV5Fchqug4s9J8/cxYNkj8vc8w8psV8rfe27M+3zshIQv5kY9IUgetbM4Qy486UWDUYv+IZEumjyOrG44S8kfJIsSPLMGEJFtUC/cssXwJkgYqPysJ4XttMn1kVeFYScjfGoOrfooTTIaxPy4J+RFkeGKGl0SxodX1679mmRMFkHGq/SjiZ1XOR58ZAiR+BMgGlLM22BsCmjEfHZxVJeQfcRR9y05lw6rIP0o8f5WPZuJdfSM5wsEVfFpWkX+kQP2sjeMYx59Vm9xV/A3xE40Bi+Otzj/P2lXx6WLOahj7w5PVMDaBfJ1VwZGQImutz1ibCpCK/Ao/o9aaWRv7WTYMLjlinKYWTBsphD9royR5YYZ/1u0iXHKWn1FrzayN/SwbZGwZsSF+Wja38OZzw3gAeLIaxia4NFmRSvMtIAvs9VeEqKJg67MW6nYZFfkzflQ2GawkioieYQvRO4j0OWPTeleUeJhKEbMqUKxekYKRoMkc+chK4lxNJM6sAokK+adV5K+hSrbUmFWBAu1zb/zKpE40biohiyKROFIU0bKJxqAV2y2cYDKMB4B8siLi/wxPzNggyHCxkeLtM2RiUXFLBli5iQisd6NuozW2Z/dc5blyzooI1FmL/aqTwUaI/0cK4a9WLszEO+rkBVZFRFQUcRZLxs9Uzjpzsb9ulyWWr6+jWFScVSVqzzyPTLvRMxvVZ5R/RtetMbjqx5zVMB4AnqyGsQnom89HrYv1cstWLLPWfFnjNKoKvirPoOKsqo39qvyGpCJ/6eQ/Rbguxohl1pova5xUonbVc0bGKWPD8DPiHcyszX6YkuashrE1PFkNYxN0TdaMKL/+142ICjKL2LWfjA0i5EdiGzVOiKigBeR5ZOI/8/P01D5RIAJyCkELvX5RGyaWqBTBqDQ/qgIfqyKFqjr9jORXpo9KsQLjFIJVqvhPrcivWIRHhNcZ35n4Zwr5Ff1hJI8yfVSKFZA+R7GpBCet/tTtHE4wGcbe8GQ1jE1A56wRr5q1kbnlGz0dTbHhnnUi3IxK88x2EZ7OqFTIsjnrc6Y/QzlrCb7Hz9rI3PI9c/N5FBvKxZD+9D5DZbtXY8m+Gyqbq/0p5qyGsT88WQ1jE8iPfMwcmVj/u8+gt+pdyzdLSMGoooj0GTmWMIM6NtYxnZl2I2REEUifkQohmUoRzDEYvusGFQggC+zRwj3LZtYRg8ium5GCk1HVDWdVRGwli876nBkDWqWI0pkgKGBSh7XAHrXLsukdF1aCRmXTG1umj0hSUNlnpD+Rn6jPmTFwpQjDeADQJ2v91wDhuS0gYu2I47FsrvYH5XwZIDY10IqIDIE9o8/IhgfkJIaIP1+tJEnnrKqqfYwKfKxqDKMqXTA4n0oIv8ppBywh/yhRxLDqhqWTpxwHrxrA2XVJckmEu4zkO2fjgnI+ho2ySkLPNTN+xE/G5uoYmLMaxgPAk9UwNoGcsyLf/RUbsVufoZuqZ/kZVekPWWeduf6p2hTOqMjf25+h1Q0Z3/0VG7Fbn42KheUnY6Pwg3L7aJxY+QDEhvHetvxc7c+HKWnOahhbw5PVMDbBpeqGiBCbVd2wN7bWZyzUfjLIiPLra1V1wKgNdHEfaVfxbmSQEf9nxoAhvrgbY/RyPV088hGp9McSeDMWuhmJn5mCjRFjgMaPiEdm+mEcTTJMyI8c+Vh/NlLUnmk3smEkfmYKNmooxgCNX1FpUeknsmGILyyKMIwHgHzzeWbDd4Yv1DaRnwyvYgmtEfF//VeTBWQjNmMMVGJ5FRh+Wv2J0NrAkcUlUYRq87ZisbzVzsgTBRhcUiVEUAn5GWL5Ee8G87SAs3syQgqZkB/hOyr+hnDhkZsKonFicVYktqtjkBl/ZZ8jG1Yxg8z71eunvsec1TAeAJ6shrEJhq+zsrhYr6i9dQ/Cn1WnELB4ruJENdYpcrMq5atOuGPx9iGnyM3kYjVGrX+yhPyZPp9d37vn6jghsWXbzcSvsFFx1ui5Z8egbucwZzWMveHJahibgD5Zo0V4dsW3v5BZUGecDsBauI/8ZJCpiJgRRSBA2lXFgiB67qjIQxpz9KL0JpgYCRpW1XjVpoLeRfiRmwyQqglXky3MdlfdjJEV3USiiKUSTIWQoEEXpM+uW58pKwZE/UH81GBVTYhsRgoPRo1T9NzR6hLRGEQ2xQkmw9gf8or8iA3CxTIcg3GKnEr8H93TEuVnTlTLPA/Epr4nM5ZILDVUzyODzDsYYZiQf9TmcwUXa7U7s7rhqqJ8Vj5gpKhj1KkQinyArLrhqM3nCi7WapfFd6JYUFFE5IcliohsVFxyVPyqypG972DGppizGsb+8GQ1jE0gX2dVidpVJ58zqrmzThRAquuPOnlupY3kqxQDiDhq5rlPXWddpXjYcYw7xZzBc9Hq+pHNqDVs1XNeuRhAxFEzz72YsxrG/vBkNYxNMHyyIuKFjOA+skE3FaxUxT+KTSUkR8a2hcwzi2yidrNjgLyDUWyRKOXeKRFZDE8wMY5vZFbg601OsCoiMo5vXEWIUAqWxGEdWamoVIg+j62F/KxkS31dg9VubcMSdSA2Z2N9b/wRG8bYIrGokmqZ2FTP46pNcYLJMPaHJ6thbAJ6Rf5VK6hnY1HwQlZF+1mnsK20eX71DQ9bVeTP3BPZqDZIqxbHIxuWKCJjE/lRCU5UnLX3Oc/e8IA8w1u4Ir9hPAA8WQ1jE1zirDud3JaNZaXq+quI2llF7hg8fZUN96iQP+K5tHXW0smz/vqeHnEBhC/UsYzikjO5cWb8r45Tlltm2u0d25Frpr3vbfZ59MbmdVbDeEB4shrGJpBP1kxFwVFV45HqhoiQv/bDQv21qBX/2fWRFM9nnkfdLgJWRUQk/sgG2UwSxX91o4U8waSqztCb4GjZjRJ1qEQFq1Txv+cbESswNiIwNnAgm0nQZ7hMgklVnaG2QdodJeoYJb5QCRHQhAzyzJA+R35GbSZB+1P7PpxgMoy94clqGJtALuSfWY0OOdFuFudenVfNqEi5Y0XErYX8M6vRRe2uxLlX51W9Y7uSzch38Gz8M0IKC/kN4wHgyWoYm4DOWRmnsCl4Sesz1toso/I/i78x1jZZm8+RSv8MUf6ojf2Zzee9scnWWdF1vbPrTLsIL2l9plozHbWeq1rbRJ5zJhaGDYvbR31mbT7vje3DlDRnNYyt4clqGJvg0mRlVWrvbfeeiLq+rpER5dftqrCyn
4xNRvgeIXqu6CYPxthmNklEvjMbK3pOVaAL+Vda7GdUcMgkixg2SGUCpGqCShTBOkVBkfhhVVEcsRlDJopYfbG/bhdZ6FYlmDLxI/1BxgmJjfHMRiV+WBUpkHHqtbEowjAeAJ6shrEJ6KfIMQQCrKp3qlhmnag264Q+RgVEVgVB1QkPDDEP4kfGWTOnyNVgCeEjm5GxnF1nYsnEpqrIrxD/34v/arss8UhmbJH+MPyYsxrGA8KT1TA2wfCTz0dWtB+1zsroM4tXKdYPVafgqTYvME4xV50itxRnLcH3eNV66Mx1VkafWbxqxBgoq+sjNmd+s+PU2+eskD/yY85qGA8IT1bD2ATDJ2tGlJ8R8iMC70y7COr4EUQnFbT6nLHJjC0SW91uJpYMEJvar2psz2LNbERo+el554YnmFDhtaIiP+tEAVWlCKTqgyLZoqqUr6p0MSu2zJGPkc1SCSZUeH12XS4kcXpjUSWlMokTxCYztlF/GAk+NP6MzSqxRUm1jI0TTIbxAPBkNYxNIN98vsrC9717FDZ1H1miglFic9YpcqPE/8gpcghPV/mZdopcfc8qC9/37lHY1PGrROAzK/1FfWaILdCKiIyxHemnbucwZzWMveHJahib4NJkzSyW12AJETKIhBPRYnlrcRwBS1SQWbjvfR5o/BlRCiN+FTJihUjwgPi5FHP0AM8STEjVvkyyhZVgYhyfgRwTqaggOLMixUhRh2KcWBU1Rhy5MfX4DCTZwkow1e0yxAqqZMsogQBrdwzyzBjxsxI/iE00BoydOsUJJsPYH12TteYUqor8Ldz6zXKm+i8TUpG/7lMGDJsWR4piQ/MBkU2mXRbHi5CJJXruKm5f+0bF//dwSciP8E+Us16ttHCvHcXCPUP8r6pmwBJSjKoUMaqKIqva5LJCfoR/qoQUKy3cM8T/qmoGioqIGd9Znqvo8yhubyG/YRilFIsiDGMbDK/Ir6oav5KQf+bmc8UJAyyB/ShR/ionwmVtJOusjIr8DI60upB/5ubzyI9izTHje6QoHxmnTGwqmzrew5zVMPaGJ6thbAL5ZGUI4RE/mUV4RKCOitpr1F9xkMV+RBShEstnRBG9QhZ0wwMytpn+ZMT/EaaJIhDSjVZNUBxLOOooDFQUMaMKJJpgGnWUB5LIUglOMtUNM+/KtOqGLQJ9C1WCCUl2IQkZlk0Um0rgoEjq3IslGieVkB+JLWOTGaeeWFs2FkUYxgPAk9UwNoFcyM/in8gCO2PRGhHlZ3j6qA3rDCEC0mcWl2Rt8J41tstw1pH8E7HpvQfluZGfmRvWEZv6Gc7kkqr3SdGfjJ9Wf25hzmoYDwBPVsPYBPJT5EZV5EdE7CgXG3WK3Cob1kcWfFMUP1O9g6zCbMucIod8j1dwsdLgNzN5LjJOkR/VhnUWx0Oec/QMVcUMRsVmzmoYDwhPVsPYBPTJWv/rRmyQSnmRTTaWSGzOOlEg02eGKB/ZZJCpTj9q8wLS58xzZ1RrrP0oTxAoZYKQf2TiRFHBgeVnVNU+VaWIEacosPyoKlJEVR+WTzDVmJk4ie5BEg0jqxtGftBF+J7r7D0sUYTCDytZdNZmpl0nmAzjHwJPVsPYBMOF/LNE4a17EC7J4rmqE+EUm8RX2giv4qyzTgKsP5NVN0T4jkpEPYpLsnhuZKPiuSzxPBJLps8Mmxosm7PrrJ9obD9MSXNWw9ganqyGsQm6OGsp5XUp5ZbQ/LeU8v7kOnPPKJuVYnk0m5Vi2d3m9XEc/ykt3Pt+nPkppbw9u87cM8pmpVgezWalWB7N5vbHX4MNYxN4shrGJrg6Wf8XXGfuGWWzUiyPZrNSLI9m8xFhgskwjDXgr8GGsQk8WQ1jE3iyGsYm8GQ1jE3gyWoYm+BPLWfeuIX+4NIAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] @@ -43,15 +43,15 @@ ], "source": [ "np.random.seed(0)\n", - "init_state = State(atm.get_random_state((30, 30)), Rules.CONWAY)\n", + "init_state = State(atm.get_random_state((50, 50)), Rules.CONWAY)\n", "#print(init_state)\n", "ans.plot_state(init_state.values)" ] }, { "cell_type": "code", - "execution_count": 8, - "id": "30a5df76", + "execution_count": 61, + "id": "aa09247e", "metadata": {}, "outputs": [], "source": [ @@ -66,26 +66,31 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "2f56a560", + "execution_count": 11, + "id": "b0d18019", "metadata": {}, "outputs": [], "source": [ - "from tqdm import tqdm" + "from tqdm import tqdm, trange" ] }, { "cell_type": "code", - "execution_count": 9, - "id": "69b95a7e", + "execution_count": 63, + "id": "58703675", "metadata": {}, "outputs": [], "source": [ - "def run(init_state, action_func, reward_func, episode_length = 10):\n", - " agent = QLearningAgent(action_func = action_func, reward_func = reward_func)\n", + "def run(init_state, action_func, reward_func, episode_length = 10, num_training = 100, num_testing = 100):\n", + " agent = QLearningAgent(action_func = action_func, reward_func = reward_func, num_training = num_training)\n", " \n", " state = init_state.copy()\n", - " while agent.episodes_so_far < agent.num_training + 50:\n", + " \n", + " train_rewards = []\n", + " test_rewards = []\n", + " \n", + " pbar = tqdm(total = agent.num_training + num_testing)\n", + " while agent.episodes_so_far < agent.num_training + num_testing:\n", " \n", " agent.start_episode()\n", " for i in range(episode_length):\n", @@ -97,572 +102,90 @@ " state = state.get_successor(action)\n", " \n", " agent.stop_episode()\n", - " print(f\"Train Rewards: {agent.accum_train_rewards}\")\n", - " print(f\"Test Rewards: {agent.accum_test_rewards}\\n\")\n", - " " + " \n", + " if agent.is_in_training():\n", + " train_rewards += [agent.accum_train_rewards]\n", + " else:\n", + " test_rewards += [agent.accum_test_rewards]\n", + " \n", + " pbar.update(1)\n", + " pbar.close()\n", + " \n", + " \n", + " plt.plot(np.arange(agent.num_training-1), train_rewards)\n", + " plt.plot(np.arange(agent.num_training, agent.num_training + num_testing + 1), test_rewards)\n", + " plt.show()\n", + " \n", + " \n", + " return agent" ] }, { "cell_type": "code", - "execution_count": 10, - "id": "a4147c0f", + "execution_count": null, + "id": "863b67a5", "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "Train Rewards: 0.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -284.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -328.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -328.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -308.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -315.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -330.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -328.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -298.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -296.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -322.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -338.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -324.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -366.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -355.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -328.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -323.0\n", - "Test Rewards: 0.0\n", - 
"\n", - "Train Rewards: -300.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -326.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -302.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -290.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -277.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -253.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -242.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -236.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -301.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -304.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -300.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -280.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -284.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -318.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -302.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -287.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -260.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -285.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -260.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -222.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -228.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -257.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -264.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -255.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -233.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -256.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -248.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -264.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -276.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -282.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -263.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -229.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -251.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -283.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -265.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -276.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -281.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -264.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -222.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -232.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -239.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -259.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -261.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -278.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -268.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -253.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -257.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -262.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -271.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -255.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -230.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -235.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -256.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -232.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -237.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -228.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train 
Rewards: -243.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -240.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -240.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -240.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -240.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -240.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -238.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -236.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -241.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -231.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -227.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -205.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -202.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -199.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -195.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -195.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -198.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -203.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -216.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -213.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -223.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -222.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -223.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -223.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -223.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -223.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 0.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 1.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 33.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 27.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 37.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 48.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 29.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 41.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 41.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 52.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 12.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 38.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 20.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 15.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 4.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 4.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 4.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 4.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 4.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 4.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 4.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 8.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 8.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 8.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 8.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 8.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 25.0\n", - "\n", - "Train 
Rewards: -225.0\n", - "Test Rewards: 6.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 25.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 20.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 30.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 8.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 11.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 23.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 16.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 18.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 23.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 25.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: 19.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: -4.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: -22.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: -22.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: -22.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: -22.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: -22.0\n", - "\n", - "Train Rewards: -225.0\n", - "Test Rewards: -22.0\n", - "\n" + " 18%|███████████████▏ | 73/400 [02:50<12:51, 2.36s/it]" ] } ], "source": [ - "run(init_state, action_func, reward_func)" + "agent = run(init_state, action_func, reward_func, num_training = 200, num_testing = 200)" ] }, { "cell_type": "code", - "execution_count": 25, - "id": "1e634653", + "execution_count": 19, + "id": "f08dd2a2", "metadata": {}, "outputs": [ { "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAMW0lEQVR4nO2dMW4WTRKG26vNYEWykgUih9x3+k/glBtwgj0Bt+ACOIccgZA2hHg2+NeZu/tz1bxT/c73PJKj9kx19cwr2fOqqm62bWsAsD7/qN4AAFwGYgUwAbECmIBYAUxArAAmIFYAE/75nF++ubnp+jzv379vL168eHLtz58/w7WvX792Y2buG1lT3XeW5+vXr9vPnz+fXJudQe++o3vO1jNrb9686cb88eNH6FrV2WZiZt6h0X63bbvpLVz801rbej+fP3/eeszWVPeNrKnuO8vz48eP4TOI3HO2nlkbEb1WdbaZmCMy+906+uPPYAATECuACYgVwATECmDCs8R6d3fX/fj08PDQbm5unvx5eHhQ7X+IYj+zPEdrvbPbksUUvXuOntdsfba2Gpk8o+9J77rMuzA625vZi3Jzc/NXa+2v1lq7vb29+/Tp05O/9+vXr/b9+/cn196+fdtub2+fXPv9+3f79u1bN/67d+/ay5cvu9eO1nr3He1ndt9ZnpEzmN03cwa9tcy1mbPNvCeqPCO5zN7b6Ltwf3/fvnz5krdu7u7uup+jM5/He9c1kSU0sxdG962wNFa0qBR5rmajZN7baJ7/1xjWDYAziBXABMQKYAJiBTABsQKYYG3dZGyU3lrm2qj10Joml4xFtZqlobLiFGuXXNvL5f7+fp+qm9WsG0XVSObaqPWgyiVjUa1maaisuCpbbHRGG9YNgDeIFcAExApgAmIFMAGxApiwm1hVZUgqent93G9vLcOopGpEplzt6DKuGSuWqyliKtjNZ814diPPU+WzqkrkFP7jrEQuerYKzzhb8hjJpco3V+R5iM+q6sBXUa6WiRk5g9l+q7obRs4nW/IY3WsmT8XzxGcFuGIQK4AJiBXABMQKYMJhYu390zzrwKeKGSVjPYzyVHWOzJytwu6YUfHM9t7rJWcb2c9h1o2i1EjRKS+z3zN1VMzYHVErrsL+y5QCKt6hQ7obVpQaqcrV6KioGVrVLrA0omebibnSO0R3Q4ATgFgBTECsACYgVgATniXWTPVC9FP/6FrVJ/lRzAyqKg3FGVTFPLoKSLFXlR6eZd28evXq7sOHD0/+nqoDX8XApkzMiuqiio6KR1ejuMWM6mG3qpt2cGXDttUMbKqogKmwqJyqUdxiZuytDesGwBvECmACYgUwAbECmIBYAUxgMNWO115Tw7TVKmBUw7CijeHKrRsGU+ksDbeGadGYK9ooo/uqYo6u3bBuALxBrAAmIFYAExArgAmIFcCE5QdTqcrVenvdJlbWigO4RvuJ5tlavFwtE3O1s42iKN9cfjCVosuesixPkWdFKaDCm8zs91oGcFkPplJ14BuxmudZUQro1jky8w7tvZ/Me9LwWQH8QawAJiBWABMQK4AJu3U3zND7h3qbWAGZtWinwVlnush+ZuuqQVCK8xvtNfuenInR2fbYrbuhovTp8dq9bQtVeZPK0nAbwLXaYKoVrRv5YKom+FSdGdiksAEe9xRdi+xntn4mGwXrhsFUAKcHsQKYgFgBTECsACY8S6yjz/kzetfNbBTVwCtVnlEytpiiokllUanIxDz6/YpyWNXNajaKqqNixZAoRXWRW9WNKs9oh8xy6yZTdTNaawU2yiim25CoyH6yMaN5qqwbVZ6Rtdk7hHUDcHIQK4AJiBXABMQKYAJiBTBh+cFUFY3EVqv0ccvT6WxVtlg0l0OsmxUHNkXWMteqLCq3PCtiRs9WZYth3QBcMYgVwATECmACYgUwAbECmLBbd8PMIJ7edY/X9tZUZErHoiVVmTwrSuQi
+1GW+10Dh3U3XM2XqyhXq+j6V1EipxjYpPLqT+uzNkFJ0IpeoFNJVVXpWGQ/l+SpiInPCgCHglgBTECsACYgVgATDhNrhb0QWbvEhoquRXEbTNW7bhPabao8R2QsqgjLD6Zys1EUg4wUFssl+60YNKawbiqGYZV3N2wHf5LfNj8bJXIGs5ir2QuZ51lh3ajyVMTEugE4AYgVwATECmACYgUw4TCx9v5pnn0er7BRVEOievvZti0VM5pnhlHMjKWhyCVjfVWcbY+rrbpRdRpcrQJmtbOtsMUU56fKcwnrJrKWubaq0+DoniPchkSN1o5+T7JVNyvliXUDcAIQK4AJiBXABMQKYAJiBTDhsMFUigoYVUynBlsVlSFKS2O1qpuj87y/v2/bttUOpoqstckncFVMpwZbFZUhSksjcn4r2kWZazesGwBvECuACYgVwATECmACYgUwYfnBVG6lYyNGpWOqIVGK/UbLymbPc1RGqHqHFOVz28QOHV07yvPU3Q1XG2RU4e1mfNao55kZ+hUtV1MNGju6/HCJErnedSsOT6roqLhiF8fo2WaeZ+S+2ZiKPCmRA7hiECuACYgVwATECmDCqcUa/SSv6qjYW5utzxjds4LM0K/IfWeDxmYozk9hF53aulltMFX0U3/G7jhTuVq0FFDV3VBhF+1WItfMrJvI2iW57L02W1dZVG7latHn6WQX/S1JrBsAaxArgAmIFcAExApgAmIFMGGJ7oYV1k00ZoVFpbA7Hve0tyWUfZ7RmKqKpqOri3arulF1N1R8AlfFrLCoVhzYpHqe0ZiZPKP7UeRJ1Q3ACUCsACYgVgATECuACYgVwITDqm4UA35WHEyl+tS/WnWRahiWokmbolLq1A3TRmvR+67YvCxyBqpclNVFqpiKs12tioqGaQAnB7ECmIBYAUxArAAmIFYAE54l1kwHviiZYViRtdn6jIoBUqM8MigGSGUGU0W7G2Y6DUafp2KY2G4lchm/KlrGpehCOFtfsRTQqbuhqiyvYuiXohTwkBK5ijKuFTsNjtZGVAzgquhu6FYip3qePfBZAU4AYgUwAbECmIBYAUx4llhVn7l7/1BvQhslM8goksvM0sjYUBV2UcR+uSSXCvY+90tsxdLBVKrSMYUllCnjynTnc8ozet+ZpVGRp2owlcIWO2Qwlap0bLUyrmvJM3rfmaVxpsFUirP9W5JYNwDWIFYAExArgAmIFcAExApgwvLWzWpDolQVMKrKEFV3w6O7/mViOnWOPKS7YUXXv4ohUaoKGLfuhpE11X3P1DmSqhuAE4BYAUxArAAmIFYAExArgAmHNUxbzUZRNUw7S8yKBnhVDdOiMU/bMK0tZqOMWK3qZsVKn9Fab6/ZZxa5pzLmCBqmAVwxiBXABMQKYAJiBTABsQKYsFt3wxm9L1yZLnEzKrobRvYzi1nV3fDobpWz7obR6zJEYy7d3TBTbhXtEqcofcrGVA2JUnjYivuezdtVlQKWdjdU+XKqIVGqmIo8VR62ajBVJubee73kXdh7P9n9bvisAN4gVgATECuACYgVwIRniXX0OToztKp3z9l9FTbAdoElFCWaZ4aRXZQZJlYRU2VRKWIqzna3EjlFKdvsvqruhhXlaqo8M6VjFV0cr6VErre2m3UzKpFTlLLN7lthabh1ceytXfJconlmYlbkGY0ZfZ6z/W5YNwDeIFYAExArgAmIFcAExApgwvKDqTI2StQucrNuFBVNqqobpyFRGbvoKgdTqapuVDGjeWasm9Ha6JlVVN04DYnK2EUMpgK4YhArgAmIFcAExApgAmIFMOEw62a1qhtVwzQne6GiAkZlF1U0TFPEXMK6Ga2NqKhGWa1JW1XVzdFnm7GLVHkeHRPrBuAEIFYAExArgAmIFcAExApgwm7dDTOdBqMDmzId5Hp73YSdBqO5zIg8k+wzUxHdT0Weo5gz5IOpRt0Nq7zAo71dVRnXaD3jwaqGRB1dOlaVZzTm0t0NV+zAN1obUVHGpSqRG+FUOlaVZzQm3Q0BrhjECmACYgUwAbECmPAssa5mL1SQsa9UeUaGQF0yJCoyTGxmhWRiVgymUhE5v91K5CrshRWHYTnZKCsOpjq6W2WVdVNaIlfx2X3FsjxFnivaYpH9ZGNGn+eK1k0PSuQATgBiBTABsQKYgFgBTECsACZYD6ZS2B2P+12po2JFnlVd/44e+pXp4pip3Fq6u2H0E7jqs3s0F7dqlKOfmdK6iexn23LDxCLXZWJi3QCcAMQKYAJiBTABsQKYgFgBTHiWddNae9da632X/3dr7b87r6nuS0xirhrz3bZt/3pypfeZ+Lk/rbUve6+p7ktMYjrG5M9gABMQK4AJe4r1P4I11X2JSUy7mNMPTACwBvwZDGACYgUwAbECmIBYAUxArAAm/A8vcPUave7SMwAAAABJRU5ErkJggg==\n", "text/plain": [ - "array([[1, 0, 0, 1],\n", - " [0, 0, 0, 1],\n", - " [0, 0, 0, 1],\n", - " [1, 1, 1, 0]])" + "
" ] }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent.observe_function(state)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "1f686987", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0" - ] + "metadata": { + "needs_background": "light" }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "action = agent.get_action(state)\n", - "action" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "d0093739", - "metadata": {}, - "outputs": [], - "source": [ - "agent.last_state" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "7948263e", - "metadata": {}, - "outputs": [], - "source": [ - "agent.do_action(state, action)" + "ans.plot_state(init_state.values)" ] }, { "cell_type": "code", - "execution_count": 21, - "id": "d9b8a338", + "execution_count": 41, + "id": "a23d5df3", "metadata": {}, "outputs": [ { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAEvUlEQVR4nO3XMWobaxSG4fNf0kXgxhACLtwp/dwNZDNagVrvQNnAXYF3oSxAWcDtUoiEgEv35xY3pZ3YIGnyDc8DfyFminOQXjQzuruAP99fcw8AvIxYIYRYIYRYIYRYIYRYIcSb19x8fX3dt7e3ZxplPt++favv37/PPcbJvX//fpF7VVV9+PCh3r59O/cYJ/f169d6eHgYT17s7hefaZp6iXa7XVfV4s5S96qq3u/3c/9szuJnY0/25zEYQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQrwq1i9fvtQYY3EHEozu/vUNY2yqalNVdXV1Nd3d3V1irou6ubmp4/E49xgnt9S9qqrW63WtVqu5xzi57XZbh8Ph6X+Q7n7xqape4tntdrPPYK/Xnf1+30s0TVP3M/15Z4UQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQYoUQo7t/fcMYm6raVFVdXV1Nd3d3l5jrom5ubup4PM49xsktda+qqvV6XavVau4xTm673dbhcBhPXuzuF5+q6iWe3W43+wz2et3Z7/e9RNM0dT/Tn8dgCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCCFWCPHmNTdP01SHw+Fcs8zm8+fP1d1zj3FyS92rqurTp0/18ePHuce4qPG7L3OMsamqTVXVu3fvpvv7+0vMdVGPj4+1Wq3mHuPklrpXVdWPHz/qeDzOPcbJbbfb6u7x5MXufvGZpqmXaL/fzz3CWSx1r+7u3W7XVbXI08/0550VQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQogVQrz53Q1jjE1VbX5+fBxj/HvekWZxXVUPcw9xBkvdq2q5u62fuzC6+5KD/JHGGIfu/nvuOU5tqXtVLXe3X+3lMRhCiBVCiPV//8w9wJksda+q5e727F7eWSGEf1YIIVYIIVYIIVYIIVYI8R9a6B1zOUX4ggAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" + "name": "stdout", + "output_type": "stream", + "text": [ + "20. Test Rewards: 12.0\n" + ] } ], "source": [ - "state = state.get_successor(action)\n", - "ans.plot_state(state.values)" + "if agent.is_in_training():\n", + " print(f\"{agent.episodes_so_far}. Train Rewards: {agent.accum_train_rewards}\")\n", + "else:\n", + " print(f\"{agent.episodes_so_far}. Test Rewards: {agent.accum_test_rewards}\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "63dd58d7", + "id": "d71f9f94", "metadata": {}, "outputs": [], "source": [] diff --git a/requirements.txt b/requirements.txt index 653ad62..c4343d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ pytest numpy pandas matplotlib -pathlib \ No newline at end of file +pathlib +tqdm \ No newline at end of file diff --git a/src/machine_learning/train.py b/src/machine_learning/train.py index ad0ec9a..0c6e425 100644 --- a/src/machine_learning/train.py +++ b/src/machine_learning/train.py @@ -1,3 +1,6 @@ +import numpy as np +from tqdm import tqdm, trange + import sys sys.path.append('../') @@ -7,22 +10,23 @@ from machine_learning.q_learning_agents import QLearningAgent from machine_learning.state import State -def run(init_state, rule = Rules.CONWAY, **args): - state = State(init_state, step_num = 0, rule = rule) - - agent = QLearningAgent(action_func = ..., reward_func = ..., **args) - - while ...: - - # observe the current state - state = agent.observe_function(state) - - action = agent.get_action(state) - - agent.do_action(state, action) - - state = state.get_successor(action) - - - - print(f"Rewards: {agent.accum_train_rewards}") +def run(init_state, agent episode_length = 10): + #agent = QLearningAgent(action_func = action_func, reward_func = reward_func) + + state = init_state.copy() + while agent.episodes_so_far < agent.num_training + 50: + + agent.start_episode() + for i in trange(episode_length): + _ = agent.observe_function(state) + + action = agent.get_action(state) + agent.do_action(state, action) + + state = state.get_successor(action) + + agent.stop_episode() + print(f"Train Rewards: {agent.accum_train_rewards}") + print(f"Test Rewards: {agent.accum_test_rewards}\n") + + #return agent