diff --git a/nb/automata-analysis-v0.ipynb b/nb/automata-analysis-v0.ipynb index c6e7110..e451de3 100644 --- a/nb/automata-analysis-v0.ipynb +++ b/nb/automata-analysis-v0.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "d2c27cd8", + "id": "df406f50", "metadata": {}, "source": [ "*This notebook is my own notes / development playground. It will be disorganized in nature. Feel free to skim through.*" @@ -45,7 +45,7 @@ { "cell_type": "code", "execution_count": 24, - "id": "8de635ad", + "id": "50a54b48", "metadata": {}, "outputs": [ { @@ -98,7 +98,7 @@ }, { "cell_type": "markdown", - "id": "21a1425a", + "id": "2cf55a8a", "metadata": {}, "source": [ "### Cycle Detection\n", @@ -136,7 +136,7 @@ { "cell_type": "code", "execution_count": 24, - "id": "57e9c62a", + "id": "5cdf66b1", "metadata": {}, "outputs": [], "source": [ @@ -146,30 +146,32 @@ }, { "cell_type": "code", - "execution_count": 26, - "id": "f11c4e73", + "execution_count": 41, + "id": "5e6fa5b4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([[1, 1],\n", - " [1, 1]])" + "True" ] }, - "execution_count": 26, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "atm.update(state).astype(int)" + "np.all(atm.update([\n", + "\t\t[0, 1, 1, 1, 0, 0],\n", + "\t\t[0, 1, 0, 0, 0, 0]\n", + "\t], rule = Rules.CONWAY) == [[0, 1, 1, 0, 0, 0], [0, 1, 0, 0, 0, 0]])" ] }, { "cell_type": "code", "execution_count": 27, - "id": "8cd5ae61", + "id": "f3cd690c", "metadata": {}, "outputs": [ { @@ -191,7 +193,7 @@ { "cell_type": "code", "execution_count": 28, - "id": "269928a7", + "id": "80f43b73", "metadata": {}, "outputs": [ { @@ -213,7 +215,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "36f320fe", + "id": "159388c7", "metadata": {}, "outputs": [ { @@ -239,7 +241,7 @@ { "cell_type": "code", "execution_count": 13, - "id": "9de1f071", + "id": "0212826a", "metadata": {}, "outputs": [ { @@ -264,7 +266,7 @@ { "cell_type": "code", "execution_count": 24, - "id": "3e8a6ca5", + "id": "fb5eadd6", "metadata": {}, "outputs": [ { @@ -292,7 +294,7 @@ { "cell_type": "code", "execution_count": 22, - "id": "401e873b", + "id": "7431a6bf", "metadata": {}, "outputs": [ { @@ -327,7 +329,7 @@ }, { "cell_type": "markdown", - "id": "6b08f1f5", + "id": "12fcf5e0", "metadata": {}, "source": [ "### Testing Rule Integration" @@ -336,7 +338,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "cbec3ee3", + "id": "5ca955c4", "metadata": {}, "outputs": [ { @@ -358,7 +360,7 @@ { "cell_type": "code", "execution_count": 16, - "id": "d7689c6c", + "id": "18b63b0d", "metadata": {}, "outputs": [ { @@ -382,7 +384,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "855b4395", + "id": "3ccd93ea", "metadata": {}, "outputs": [ { @@ -417,7 +419,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "f3549ae1", + "id": "50999995", "metadata": {}, "outputs": [ { @@ -455,29 +457,29 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "4d8c8947", + "execution_count": 34, + "id": "9fa22e77", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(1, 1, 7)" + "(8, 2)" ] }, - "execution_count": 5, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "(1, *state.shape)" + "atm.get_neighbors(1, 1, (3,3)).shape" ] }, { "cell_type": "code", "execution_count": null, - "id": "19276fc4", + "id": "16cb1496", "metadata": {}, "outputs": [], "source": [] diff --git a/src/analysis.py b/src/analysis.py index 87b4cf8..64e78f7 100644 --- a/src/analysis.py +++ b/src/analysis.py @@ -3,21 +3,21 @@ from rules import Rules def get_total_alive(state): - if type(state) != np.array: - state = np.array(state, dtype = int) + state = state_fix(state) return (state == 1).sum() def get_total_dead(state): - if type(state) != np.array: - state = np.array(state, dtype = int) + state = state_fix(state) return (state == 0).sum() def is_terminal_state(state): + state = state_fix(state) return get_total_alive(state) == 0 # ------------------------------ def get_survival_stats(states): + state = state_fix(state) return np.array([get_total_alive(state) for state in states], dtype = int) # each cell represents how many alive neighbors it has @@ -25,6 +25,8 @@ def get_survival_stats(states): O(8n^2)... whatever ''' def get_alive_matrix(state): + state = state_fix(state) + x, y = state.shape alive_mat = np.zeros(state.shape) for i in range(x): @@ -62,6 +64,8 @@ def display_state(state): Use matplotlib to plot a state, works for 1D as well ''' def plot_state(state): + state = state_fix(state) + fig, axs = plt.subplots() if state.shape[0] > 1: axs.imshow(~state[0], cmap = 'gray') @@ -78,8 +82,7 @@ def plot_state(state): # ------------------------------ def play(state, steps, rule = Rules.CONWAY, verbose = False, verbose_func = display_state): - if type(state) != np.array: - state = np.array(state, dtype = int) + state = state_fix(state) i = 1 states = np.zeros((steps, *state.shape), dtype = int) diff --git a/src/automata.py b/src/automata.py index 122a2e4..83787b8 100644 --- a/src/automata.py +++ b/src/automata.py @@ -28,8 +28,7 @@ NOTE: this is *not* updating in-place ''' def update(state, rule): - if type(state) != np.array: - state = np.array(state, dtype = int) + state = state_fix(state) x, y = state.shape new_state = np.zeros(state.shape, dtype = int) @@ -63,4 +62,12 @@ def get_neighbors(i, j, shape): def in_range(i, j, shape): if i < shape[0] and i > -1 and j > -1 and j < shape[1]: return True - return False \ No newline at end of file + return False + +def state_fix(state): + if type(state) != np.array: + state = np.array(state, dtype = int) + if len(state.shape) == 1: + state = state.reshape((1, -1)) + + return state \ No newline at end of file diff --git a/tests/test_class.py b/tests/test_class.py index ada2889..f69eaf5 100644 --- a/tests/test_class.py +++ b/tests/test_class.py @@ -15,3 +15,63 @@ def test_tautology(): # ------------------------------ AUTOMATA.PY +def test_in_range_0(): + assert not atm.in_range(-1, 0, (4, 4)) + +def test_in_range_1(): + assert not atm.in_range(0, -1, (4, 4)) + +def test_in_range_2(): + assert not atm.in_range(-1, -1, (4, 4)) + +def test_in_range_3(): + assert not atm.in_range(2, 5, (4, 4)) + +def test_in_range_4(): + assert not atm.in_range(5, 3, (4, 4)) + +def test_in_range_5(): + assert atm.in_range(2, 2, (4, 4)) + +def test_in_range_6(): + assert atm.in_range(2, 3, (4, 4)) + +# ------------------------------ + +# check shape correct +def test_get_neighbors_0(): + assert atm.get_neighbors(1, 1, (3, 3)).shape[1] == 2 + +# check one valid +def test_get_neighbors_1(): + assert atm.get_neighbors(-1, -1, (3, 3)).shape[0] == 1 + +# check none valid +def test_get_neighbors_2(): + assert atm.get_neighbors(-2, -2, (3, 3)).shape[0] == 0 + +# check all valid +def test_get_neighbors_3(): + assert atm.get_neighbors(1, 1, (3, 3)).shape[0] == 8 + +# check center not included +def test_get_neighbors_4(): + assert ~np.all(atm.get_neighbors(1, 1, (3, 3)) == [1, 1]) + +# ------------------------------ + +def test_update_0(): + state = [0, 0, 1, 1, 0, 0] + assert np.all(atm.update(state, rule = Rules.CONWAY) == [0, 0, 0, 0, 0, 0]) + + +def test_update_1(): + state = [ + [0, 1, 1, 1, 0, 0], + [0, 1, 0, 0, 0, 0] + ] + upd = atm.update(state, rule = Rules.CONWAY) + print(upd) + assert np.all(upd == [[0, 1, 1, 0, 0, 0], [0, 1, 0, 0, 0, 0]]) + +