From d566375aa6292892bfdc482e4d7af76f2bb14a2a Mon Sep 17 00:00:00 2001 From: Tambet Matiisen Date: Wed, 20 Jun 2018 21:26:44 +0300 Subject: [PATCH] Initial commit. --- imitation/AlphaGoZero model.ipynb | 365 ++++++++++++++++++++ imitation/Clean.ipynb | 469 ++++++++++++++++++++++++++ imitation/Conv model.ipynb | 294 ++++++++++++++++ imitation/Discount.ipynb | 161 +++++++++ imitation/Fully connected model.ipynb | 240 +++++++++++++ imitation/Linear model.ipynb | 262 ++++++++++++++ imitation/collect_simple.py | 87 +++++ imitation/eval_model.py | 122 +++++++ mcts/mcts_agent.py | 234 +++++++++++++ 9 files changed, 2234 insertions(+) create mode 100644 imitation/AlphaGoZero model.ipynb create mode 100644 imitation/Clean.ipynb create mode 100644 imitation/Conv model.ipynb create mode 100644 imitation/Discount.ipynb create mode 100644 imitation/Fully connected model.ipynb create mode 100644 imitation/Linear model.ipynb create mode 100644 imitation/collect_simple.py create mode 100644 imitation/eval_model.py create mode 100644 mcts/mcts_agent.py diff --git a/imitation/AlphaGoZero model.ipynb b/imitation/AlphaGoZero model.ipynb new file mode 100644 index 0000000..2df17ee --- /dev/null +++ b/imitation/AlphaGoZero model.ipynb @@ -0,0 +1,365 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from keras.models import Model\n", + "from keras.layers import Input, Conv2D, Flatten, Dense, BatchNormalization, Activation, add\n", + "import tensorflow as tf\n", + "import keras.backend as K\n", + "from sklearn.metrics import explained_variance_score, accuracy_score\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# make sure TF does not allocate all memory\n", + "config = tf.ConfigProto()\n", + "config.gpu_options.allow_growth = True\n", + 
"K.set_session(tf.Session(config=config))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((594333, 11, 11, 18), (594333,), (594333,))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = np.load('simple_600K_disc0.99_cleaned.npz')\n", + "x_train = data['observations']\n", + "p_train = data['actions']\n", + "v_train = data['rewards']\n", + "x_train.shape, p_train.shape, v_train.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((95623, 11, 11, 18), (95623,), (95623,))" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = np.load('valid_100K_disc0.99_cleaned.npz')\n", + "x_test = data['observations']\n", + "p_test = data['actions']\n", + "v_test = data['rewards']\n", + "x_test.shape, p_test.shape, v_test.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "__________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + "input_5 (InputLayer) (None, 11, 11, 18) 0 \n", + "__________________________________________________________________________________________________\n", + "conv2d_14 (Conv2D) (None, 11, 11, 256) 41728 input_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_13 (BatchNo (None, 11, 11, 256) 1024 conv2d_14[0][0] \n", + "__________________________________________________________________________________________________\n", + "activation_11 (Activation) (None, 
11, 11, 256) 0 batch_normalization_13[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_15 (Conv2D) (None, 11, 11, 256) 590080 activation_11[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_14 (BatchNo (None, 11, 11, 256) 1024 conv2d_15[0][0] \n", + "__________________________________________________________________________________________________\n", + "activation_12 (Activation) (None, 11, 11, 256) 0 batch_normalization_14[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_16 (Conv2D) (None, 11, 11, 256) 590080 activation_12[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_15 (BatchNo (None, 11, 11, 256) 1024 conv2d_16[0][0] \n", + "__________________________________________________________________________________________________\n", + "add_4 (Add) (None, 11, 11, 256) 0 batch_normalization_15[0][0] \n", + " activation_11[0][0] \n", + "__________________________________________________________________________________________________\n", + "activation_13 (Activation) (None, 11, 11, 256) 0 add_4[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_17 (Conv2D) (None, 11, 11, 256) 590080 activation_13[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_16 (BatchNo (None, 11, 11, 256) 1024 conv2d_17[0][0] \n", + "__________________________________________________________________________________________________\n", + "activation_14 (Activation) (None, 11, 11, 256) 0 batch_normalization_16[0][0] \n", + 
"__________________________________________________________________________________________________\n", + "conv2d_18 (Conv2D) (None, 11, 11, 256) 590080 activation_14[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_17 (BatchNo (None, 11, 11, 256) 1024 conv2d_18[0][0] \n", + "__________________________________________________________________________________________________\n", + "add_5 (Add) (None, 11, 11, 256) 0 batch_normalization_17[0][0] \n", + " activation_13[0][0] \n", + "__________________________________________________________________________________________________\n", + "activation_15 (Activation) (None, 11, 11, 256) 0 add_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_19 (Conv2D) (None, 11, 11, 256) 590080 activation_15[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_18 (BatchNo (None, 11, 11, 256) 1024 conv2d_19[0][0] \n", + "__________________________________________________________________________________________________\n", + "activation_16 (Activation) (None, 11, 11, 256) 0 batch_normalization_18[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_20 (Conv2D) (None, 11, 11, 256) 590080 activation_16[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_19 (BatchNo (None, 11, 11, 256) 1024 conv2d_20[0][0] \n", + "__________________________________________________________________________________________________\n", + "add_6 (Add) (None, 11, 11, 256) 0 batch_normalization_19[0][0] \n", + " activation_15[0][0] \n", + "__________________________________________________________________________________________________\n", + 
"activation_17 (Activation) (None, 11, 11, 256) 0 add_6[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_22 (Conv2D) (None, 11, 11, 1) 257 activation_17[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_21 (BatchNo (None, 11, 11, 1) 4 conv2d_22[0][0] \n", + "__________________________________________________________________________________________________\n", + "activation_19 (Activation) (None, 11, 11, 1) 0 batch_normalization_21[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_21 (Conv2D) (None, 11, 11, 256) 590080 activation_17[0][0] \n", + "__________________________________________________________________________________________________\n", + "flatten_3 (Flatten) (None, 121) 0 activation_19[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_20 (BatchNo (None, 11, 11, 256) 1024 conv2d_21[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_1 (Dense) (None, 256) 31232 flatten_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "activation_18 (Activation) (None, 11, 11, 256) 0 batch_normalization_20[0][0] \n", + "__________________________________________________________________________________________________\n", + "activation_20 (Activation) (None, 256) 0 dense_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "flatten_2 (Flatten) (None, 30976) 0 activation_18[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_2 (Dense) (None, 1) 257 
activation_20[0][0] \n", + "__________________________________________________________________________________________________\n", + "p (Dense) (None, 6) 185862 flatten_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "v (Activation) (None, 1) 0 dense_2[0][0] \n", + "==================================================================================================\n", + "Total params: 4,398,092\n", + "Trainable params: 4,393,994\n", + "Non-trainable params: 4,098\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "def ConvBlock(mod):\n", + " mod = Conv2D(filters=256, kernel_size=3, strides=1, padding=\"same\")(mod)\n", + " mod = BatchNormalization()(mod)\n", + " mod = Activation('relu')(mod)\n", + " return mod\n", + "\n", + "def ResidualBlock(mod):\n", + " tmp = mod\n", + " mod = ConvBlock(mod)\n", + " mod = Conv2D(filters=256, kernel_size=3, strides=1, padding=\"same\")(mod)\n", + " mod = BatchNormalization()(mod)\n", + " mod = add([mod,tmp])\n", + " mod = Activation('relu')(mod)\n", + " return mod\n", + "\n", + "def PolicyHead(mod):\n", + " mod = Conv2D(filters=2, kernel_size=1, strides=1, padding=\"same\")(mod)\n", + " mod = BatchNormalization()(mod)\n", + " mod = Activation('relu')(mod)\n", + " mod = Flatten()(mod)\n", + " mod = Dense(6, activation='softmax', name='p')(mod)\n", + " return mod\n", + "\n", + "def ValueHead(mod):\n", + " mod = Conv2D(filters=1, kernel_size=1, strides=1, padding=\"same\")(mod)\n", + " mod = BatchNormalization()(mod)\n", + " mod = Activation('relu')(mod)\n", + " mod = Flatten()(mod)\n", + " mod = Dense(256)(mod)\n", + " mod = Activation('relu')(mod)\n", + " mod = Dense(1)(mod)\n", + " mod = Activation('tanh',name='v')(mod)\n", + " return mod\n", + "\n", + "h = x = Input(shape=(11,11,18))\n", + "h = ConvBlock(h)\n", + "for i in range(3):\n", + " h = ResidualBlock(h)\n", + "p = 
PolicyHead(h)\n", + "v = ValueHead(h)\n", + "model = Model(x, [p, v])\n", + "model.summary()\n", + "model.compile(optimizer='adam', loss=['sparse_categorical_crossentropy', 'mse'], loss_weights=[1, 10], metrics={'p': 'accuracy'})" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 594333 samples, validate on 95623 samples\n", + "Epoch 1/10\n", + "594333/594333 [==============================] - 1263s 2ms/step - loss: 2.6285 - p_loss: 1.3796 - v_loss: 0.1249 - p_acc: 0.4511 - val_loss: 2.7722 - val_p_loss: 0.8412 - val_v_loss: 0.1931 - val_p_acc: 0.6122\n", + "Epoch 2/10\n", + "594333/594333 [==============================] - 1259s 2ms/step - loss: 1.4321 - p_loss: 0.7752 - v_loss: 0.0657 - p_acc: 0.6346 - val_loss: 2.8289 - val_p_loss: 0.7384 - val_v_loss: 0.2090 - val_p_acc: 0.6484\n", + "Epoch 3/10\n", + "594333/594333 [==============================] - 1261s 2ms/step - loss: 1.1408 - p_loss: 0.7070 - v_loss: 0.0434 - p_acc: 0.6562 - val_loss: 2.8035 - val_p_loss: 0.7012 - val_v_loss: 0.2102 - val_p_acc: 0.6604\n", + "Epoch 4/10\n", + "594333/594333 [==============================] - 1263s 2ms/step - loss: 0.9821 - p_loss: 0.6719 - v_loss: 0.0310 - p_acc: 0.6694 - val_loss: 2.7501 - val_p_loss: 0.6809 - val_v_loss: 0.2069 - val_p_acc: 0.6660\n", + "Epoch 5/10\n", + "594333/594333 [==============================] - 1271s 2ms/step - loss: 0.8873 - p_loss: 0.6493 - v_loss: 0.0238 - p_acc: 0.6804 - val_loss: 2.8450 - val_p_loss: 0.6613 - val_v_loss: 0.2184 - val_p_acc: 0.6732\n", + "Epoch 6/10\n", + "594333/594333 [==============================] - 1281s 2ms/step - loss: 0.8250 - p_loss: 0.6334 - v_loss: 0.0192 - p_acc: 0.6882 - val_loss: 2.7277 - val_p_loss: 0.6603 - val_v_loss: 0.2067 - val_p_acc: 0.6741\n", + "Epoch 7/10\n", + "594333/594333 [==============================] - 1272s 2ms/step - loss: 0.7777 - p_loss: 0.6178 - v_loss: 0.0160 - p_acc: 
0.6989 - val_loss: 2.7051 - val_p_loss: 0.6595 - val_v_loss: 0.2046 - val_p_acc: 0.6726\n", + "Epoch 8/10\n", + "594333/594333 [==============================] - 1259s 2ms/step - loss: 0.7405 - p_loss: 0.6021 - v_loss: 0.0138 - p_acc: 0.7090 - val_loss: 2.6416 - val_p_loss: 0.6597 - val_v_loss: 0.1982 - val_p_acc: 0.6717\n", + "Epoch 9/10\n", + "594333/594333 [==============================] - 1263s 2ms/step - loss: 0.7073 - p_loss: 0.5847 - v_loss: 0.0123 - p_acc: 0.7208 - val_loss: 2.6725 - val_p_loss: 0.6739 - val_v_loss: 0.1999 - val_p_acc: 0.6705\n", + "Epoch 10/10\n", + "594333/594333 [==============================] - 1272s 2ms/step - loss: 0.6754 - p_loss: 0.5653 - v_loss: 0.0110 - p_acc: 0.7338 - val_loss: 2.6304 - val_p_loss: 0.6822 - val_v_loss: 0.1948 - val_p_acc: 0.6683\n" + ] + } + ], + "source": [ + "history = model.fit(x_train, [p_train, v_train], batch_size=128, epochs=10, validation_data=(x_test, [p_test, v_test]))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5,1,'Value MSE')" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(16,6))\n", + "plt.subplot(1, 2, 1)\n", + "plt.plot(history.history['p_acc'])\n", + "plt.plot(history.history['val_p_acc'])\n", + "plt.legend(['Train', 'Validation'])\n", + "plt.title(\"Action prediction accuracy\")\n", + "plt.subplot(1, 2, 2)\n", + "plt.plot(history.history['v_loss'])\n", + "plt.plot(history.history['val_v_loss'])\n", + "plt.legend(['Train', 'Validation'])\n", + "plt.title(\"Value MSE\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy train: 0.7605870782877613\n", + "Accuracy test: 0.6682701860431067\n", + "Explained variance train: 0.9296615093133133\n", + "Explained variance test: -0.278033390090739\n" + ] + } + ], + "source": [ + "p_train_pred, v_train_pred = model.predict(x_train)\n", + "p_test_pred, v_test_pred = model.predict(x_test)\n", + "act_train_pred = np.argmax(p_train_pred, axis=1)\n", + "act_test_pred = np.argmax(p_test_pred, axis=1)\n", + "print(\"Accuracy train:\", accuracy_score(p_train, act_train_pred))\n", + "print(\"Accuracy test:\", accuracy_score(p_test, act_test_pred))\n", + "print(\"Explained variance train:\", explained_variance_score(v_train, v_train_pred))\n", + "print(\"Explained variance test:\", explained_variance_score(v_test, v_test_pred))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "model.save('AGZ.h5')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (pommer)", + "language": "python", + "name": "pommer" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": 
"python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/imitation/Clean.ipynb b/imitation/Clean.ipynb new file mode 100644 index 0000000..7d16cde --- /dev/null +++ b/imitation/Clean.ipynb @@ -0,0 +1,469 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "data = np.load(\"valid_100K_disc0.99.npz\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(108262, 11, 11, 18)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_observations = data['observations']\n", + "all_observations.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(108262,)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_actions = data['actions']\n", + "all_actions.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(108262,)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_rewards = data['rewards']\n", + "all_rewards.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "15509" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "obs_stuck = np.all(all_observations[1:] == all_observations[:-1], axis=(1,2,3))\n", + "sum(obs_stuck)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": 
[ + "13419" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "obs_stuck2 = np.logical_and(obs_stuck[:-1], obs_stuck[1:])\n", + "sum(obs_stuck2)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "12657" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "obs_stuck3 = np.logical_and(obs_stuck[:-2], np.logical_and(obs_stuck[1:-1], obs_stuck[2:]))\n", + "sum(obs_stuck3)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "12329" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "obs_stuck4 = np.logical_and(obs_stuck[:-3], np.logical_and(obs_stuck[1:-2], np.logical_and(obs_stuck[2:-1], obs_stuck[3:])))\n", + "sum(obs_stuck4)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "52496" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "act_stuck2 = (all_actions[:-1] == all_actions[1:])\n", + "sum(act_stuck2)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "13394" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stuck2 = np.logical_and(obs_stuck2, act_stuck2[:-1])\n", + "sum(stuck2)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "36767" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "act_stuck3 = np.logical_and(all_actions[:-2] == all_actions[1:-1], all_actions[:-2] == all_actions[2:])\n", + "sum(act_stuck3)" + ] + 
}, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "12639" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stuck3 = np.logical_and(obs_stuck3, act_stuck3[:-1])\n", + "sum(stuck3)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "30329" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "act_stuck4 = \\\n", + " np.logical_and(\n", + " np.logical_and(\n", + " all_actions[:-3] == all_actions[1:-2], \n", + " all_actions[:-3] == all_actions[2:-1]), \n", + " all_actions[:-3] == all_actions[3:])\n", + "sum(act_stuck4)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "12315" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stuck4 = np.logical_and(obs_stuck4, act_stuck4[:-1])\n", + "sum(stuck4)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "ind = np.where(stuck3)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([4851, 2835, 1355, 2012, 1586, 0])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.histogram(all_actions[ind], bins=range(7))[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([26224, 20963, 18259, 19520, 19438, 3858])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.histogram(all_actions, bins=range(7))[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + 
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(95623,)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cleaned_actions = np.delete(all_actions, ind, axis=0)\n", + "cleaned_actions.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([21373, 18128, 16904, 17508, 17852, 3858])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.histogram(cleaned_actions, bins=range(7))[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(95623, 11, 11, 18)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cleaned_observations = np.delete(all_observations, ind, axis=0)\n", + "cleaned_observations.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(95623,)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cleaned_rewards = np.delete(all_rewards, ind, axis=0)\n", + "cleaned_rewards.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "np.savez_compressed(\"valid_100K_disc0.99_cleaned.npz\", observations=cleaned_observations, actions=cleaned_actions, rewards=cleaned_rewards)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": 
"python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/imitation/Conv model.ipynb b/imitation/Conv model.ipynb new file mode 100644 index 0000000..4c41636 --- /dev/null +++ b/imitation/Conv model.ipynb @@ -0,0 +1,294 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from keras.models import Model, load_model\n", + "from keras.layers import Input, Conv2D, Flatten, Dense\n", + "from keras.callbacks import ModelCheckpoint, EarlyStopping\n", + "import tensorflow as tf\n", + "import keras.backend as K\n", + "from sklearn.metrics import explained_variance_score, accuracy_score\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# make sure TF does not allocate all memory\n", + "config = tf.ConfigProto()\n", + "config.gpu_options.allow_growth = True\n", + "K.set_session(tf.Session(config=config))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((594333, 11, 11, 18), (594333,), (594333,))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = np.load('simple_600K_disc0.99_cleaned.npz')\n", + "x_train = data['observations']\n", + "p_train = data['actions']\n", + "v_train = data['rewards']\n", + "x_train.shape, p_train.shape, v_train.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((95623, 11, 11, 18), (95623,), (95623,))" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = np.load('valid_100K_disc0.99_cleaned.npz')\n", + "x_test = data['observations']\n", + "p_test = data['actions']\n", 
+ "v_test = data['rewards']\n", + "x_test.shape, p_test.shape, v_test.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "__________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + "input_15 (InputLayer) (None, 11, 11, 18) 0 \n", + "__________________________________________________________________________________________________\n", + "conv2d_43 (Conv2D) (None, 11, 11, 256) 41728 input_15[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_44 (Conv2D) (None, 11, 11, 256) 590080 conv2d_43[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2d_45 (Conv2D) (None, 11, 11, 256) 590080 conv2d_44[0][0] \n", + "__________________________________________________________________________________________________\n", + "flatten_15 (Flatten) (None, 30976) 0 conv2d_45[0][0] \n", + "__________________________________________________________________________________________________\n", + "p (Dense) (None, 6) 185862 flatten_15[0][0] \n", + "__________________________________________________________________________________________________\n", + "v (Dense) (None, 1) 30977 flatten_15[0][0] \n", + "==================================================================================================\n", + "Total params: 1,438,727\n", + "Trainable params: 1,438,727\n", + "Non-trainable params: 0\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "c = x = Input(shape=(11,11,18))\n", + "c = Conv2D(filters=256, kernel_size=3, strides=1, 
activation=\"relu\", padding=\"same\")(c)\n", + "c = Conv2D(filters=256, kernel_size=3, strides=1, activation=\"relu\", padding=\"same\")(c)\n", + "c = Conv2D(filters=256, kernel_size=3, strides=1, activation=\"relu\", padding=\"same\")(c)\n", + "h = Flatten()(c)\n", + "#h = Dense(128, activation='relu')(h)\n", + "p = Dense(6, activation=\"softmax\", name='p')(h)\n", + "#h = Dense(128, activation='relu')(h)\n", + "v = Dense(1, activation=\"tanh\", name='v')(h)\n", + "model = Model(x, [p, v])\n", + "model.summary()\n", + "model.compile(optimizer='adam', loss=['sparse_categorical_crossentropy', 'mse'], loss_weights=[1, 0], metrics={'p': 'accuracy'})" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [], + "source": [ + "callbacks = [\n", + " ModelCheckpoint('conv.h5', monitor='val_p_acc', verbose=1, save_best_only=True, mode='max'),\n", + " EarlyStopping(monitor='val_p_acc', min_delta=0.001, patience=5, verbose=1, mode='max')\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 594333 samples, validate on 95623 samples\n", + "Epoch 1/10\n", + "594333/594333 [==============================] - 188s 316us/step - loss: 0.9410 - p_loss: 0.9410 - v_loss: 0.1741 - p_acc: 0.5675 - val_loss: 0.7122 - val_p_loss: 0.7122 - val_v_loss: 0.1774 - val_p_acc: 0.6569\n", + "\n", + "Epoch 00001: val_p_acc improved from -inf to 0.65693, saving model to conv.h5\n", + "Epoch 2/10\n", + "594333/594333 [==============================] - 185s 312us/step - loss: 0.6879 - p_loss: 0.6879 - v_loss: 0.1767 - p_acc: 0.6606 - val_loss: 0.6690 - val_p_loss: 0.6690 - val_v_loss: 0.1769 - val_p_acc: 0.6687\n", + "\n", + "Epoch 00002: val_p_acc improved from 0.65693 to 0.66874, saving model to conv.h5\n", + "Epoch 3/10\n", + "594333/594333 [==============================] - 185s 311us/step - loss: 0.6586 - p_loss: 0.6586 - v_loss: 
0.1778 - p_acc: 0.6723 - val_loss: 0.6630 - val_p_loss: 0.6630 - val_v_loss: 0.1778 - val_p_acc: 0.6714\n", + "\n", + "Epoch 00003: val_p_acc improved from 0.66874 to 0.67145, saving model to conv.h5\n", + "Epoch 4/10\n", + "594333/594333 [==============================] - 185s 312us/step - loss: 0.6431 - p_loss: 0.6431 - v_loss: 0.1787 - p_acc: 0.6796 - val_loss: 0.6505 - val_p_loss: 0.6505 - val_v_loss: 0.1784 - val_p_acc: 0.6752\n", + "\n", + "Epoch 00004: val_p_acc improved from 0.67145 to 0.67515, saving model to conv.h5\n", + "Epoch 5/10\n", + "594333/594333 [==============================] - 185s 312us/step - loss: 0.6323 - p_loss: 0.6323 - v_loss: 0.1797 - p_acc: 0.6853 - val_loss: 0.6551 - val_p_loss: 0.6551 - val_v_loss: 0.1795 - val_p_acc: 0.6752\n", + "\n", + "Epoch 00005: val_p_acc improved from 0.67515 to 0.67517, saving model to conv.h5\n", + "Epoch 6/10\n", + "594333/594333 [==============================] - 184s 310us/step - loss: 0.6224 - p_loss: 0.6224 - v_loss: 0.1802 - p_acc: 0.6916 - val_loss: 0.6527 - val_p_loss: 0.6527 - val_v_loss: 0.1808 - val_p_acc: 0.6764\n", + "\n", + "Epoch 00006: val_p_acc improved from 0.67517 to 0.67637, saving model to conv.h5\n", + "Epoch 7/10\n", + "594333/594333 [==============================] - 184s 309us/step - loss: 0.6127 - p_loss: 0.6127 - v_loss: 0.1815 - p_acc: 0.6980 - val_loss: 0.6527 - val_p_loss: 0.6527 - val_v_loss: 0.1793 - val_p_acc: 0.6729\n", + "\n", + "Epoch 00007: val_p_acc did not improve from 0.67637\n", + "Epoch 8/10\n", + "594333/594333 [==============================] - 184s 309us/step - loss: 0.6019 - p_loss: 0.6019 - v_loss: 0.1821 - p_acc: 0.7047 - val_loss: 0.6612 - val_p_loss: 0.6612 - val_v_loss: 0.1815 - val_p_acc: 0.6749\n", + "\n", + "Epoch 00008: val_p_acc did not improve from 0.67637\n", + "Epoch 9/10\n", + "594333/594333 [==============================] - 184s 309us/step - loss: 0.5900 - p_loss: 0.5900 - v_loss: 0.1830 - p_acc: 0.7123 - val_loss: 0.6723 - val_p_loss: 0.6723 - 
val_v_loss: 0.1820 - val_p_acc: 0.6716\n", + "\n", + "Epoch 00009: val_p_acc did not improve from 0.67637\n", + "Epoch 10/10\n", + "594333/594333 [==============================] - 184s 309us/step - loss: 0.5767 - p_loss: 0.5767 - v_loss: 0.1854 - p_acc: 0.7196 - val_loss: 0.6801 - val_p_loss: 0.6801 - val_v_loss: 0.1843 - val_p_acc: 0.6700\n", + "\n", + "Epoch 00010: val_p_acc did not improve from 0.67637\n" + ] + } + ], + "source": [ + "history = model.fit(x_train, [p_train, v_train], batch_size=512, epochs=10, validation_data=(x_test, [p_test, v_test]), callbacks=callbacks)" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5,1,'Value MSE')" + ] + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(16,6))\n", + "plt.subplot(1, 2, 1)\n", + "plt.plot(history.history['p_acc'])\n", + "plt.plot(history.history['val_p_acc'])\n", + "plt.legend(['Train', 'Validation'])\n", + "plt.title(\"Action prediction accuracy\")\n", + "plt.subplot(1, 2, 2)\n", + "plt.plot(history.history['v_loss'])\n", + "plt.plot(history.history['val_v_loss'])\n", + "plt.legend(['Train', 'Validation'])\n", + "plt.title(\"Value MSE\")" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy train: 0.7048102662985228\n", + "Accuracy test: 0.6763749307175052\n", + "Explained variance train: -0.04325305576853089\n", + "Explained variance test: -0.04068448877452058\n" + ] + } + ], + "source": [ + "model = load_model('conv.h5')\n", + "p_train_pred, v_train_pred = model.predict(x_train, batch_size=2048)\n", + "p_test_pred, v_test_pred = model.predict(x_test, batch_size=2048)\n", + "act_train_pred = np.argmax(p_train_pred, axis=1)\n", + "act_test_pred = np.argmax(p_test_pred, axis=1)\n", + "print(\"Accuracy train:\", accuracy_score(p_train, act_train_pred))\n", + "print(\"Accuracy test:\", accuracy_score(p_test, act_test_pred))\n", + "print(\"Explained variance train:\", explained_variance_score(v_train, v_train_pred))\n", + "print(\"Explained variance test:\", explained_variance_score(v_test, v_test_pred))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + 
} + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/imitation/Discount.ipynb b/imitation/Discount.ipynb new file mode 100644 index 0000000..0d14691 --- /dev/null +++ b/imitation/Discount.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "data = np.load(\"valid_100K.npz\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(108262, 11, 11, 18)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "observations = data['observations']\n", + "observations.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(108262,)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "actions = data['actions']\n", + "actions.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(108262,)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rewards = data['rewards']\n", + "rewards.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "discount = 0.99\n", + "disc_rewards = []\n", + "for r in reversed(rewards):\n", + " if r != 0:\n", + " rew = r\n", + " disc_rewards.insert(0, rew)\n", + " rew *= discount" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "np.savez_compressed(\"valid_100K_disc0.99.npz\", observations=observations, actions=actions, rewards=disc_rewards)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "discount = 1\n", + "disc_rewards = []\n", + "for r in reversed(rewards):\n", + " if r != 0:\n", + " rew = r\n", + " disc_rewards.insert(0, rew)\n", + " rew *= discount" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "np.savez_compressed(\"valid_100K_disc1.npz\", observations=observations, actions=actions, rewards=disc_rewards)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/imitation/Fully connected model.ipynb b/imitation/Fully connected model.ipynb new file mode 100644 index 0000000..1f904ef --- /dev/null +++ b/imitation/Fully connected model.ipynb @@ -0,0 +1,240 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from keras.models import Model\n", + "from keras.layers import Input, Flatten, Dense\n", + "import tensorflow as tf\n", + "import keras.backend as K\n", + "from sklearn.metrics import explained_variance_score\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# make sure TF does not allocate all memory\n", + "config = tf.ConfigProto()\n", + "config.gpu_options.allow_growth = True\n", + "K.set_session(tf.Session(config=config))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + 
"outputs": [ + { + "data": { + "text/plain": [ + "((594333, 11, 11, 18), (594333,), (594333,))" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = np.load('simple_600K_disc0.99_cleaned.npz')\n", + "x_train = data['observations']\n", + "p_train = data['actions']\n", + "v_train = data['rewards']\n", + "x_train.shape, p_train.shape, v_train.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((95623, 11, 11, 18), (95623,), (95623,))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = np.load('valid_100K_disc0.99_cleaned.npz')\n", + "x_test = data['observations']\n", + "p_test = data['actions']\n", + "v_test = data['rewards']\n", + "x_test.shape, p_test.shape, v_test.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "__________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + "input_1 (InputLayer) (None, 11, 11, 18) 0 \n", + "__________________________________________________________________________________________________\n", + "flatten_1 (Flatten) (None, 2178) 0 input_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_1 (Dense) (None, 128) 278912 flatten_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_2 (Dense) (None, 128) 16512 dense_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_3 (Dense) (None, 128) 
16512 dense_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "p (Dense) (None, 6) 774 dense_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "v (Dense) (None, 1) 129 dense_3[0][0] \n", + "==================================================================================================\n", + "Total params: 312,839\n", + "Trainable params: 312,839\n", + "Non-trainable params: 0\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "x = Input(shape=(11,11,18))\n", + "h = Flatten()(x)\n", + "h = Dense(128)(h)\n", + "h = Dense(128)(h)\n", + "p = Dense(6, activation=\"softmax\", name='p')(h)\n", + "h = Dense(128)(h)\n", + "v = Dense(1, activation=\"tanh\", name='v')(h)\n", + "model = Model(x, [p, v])\n", + "model.summary()\n", + "model.compile(optimizer='adam', loss=['sparse_categorical_crossentropy', 'mse'], metrics={'p': 'accuracy'})" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 594333 samples, validate on 95623 samples\n", + "Epoch 1/10\n", + "594333/594333 [==============================] - 23s 39us/step - loss: 1.6708 - p_loss: 1.5147 - v_loss: 0.1561 - p_acc: 0.3642 - val_loss: 1.6407 - val_p_loss: 1.4869 - val_v_loss: 0.1538 - val_p_acc: 0.3847\n", + "Epoch 2/10\n", + "594333/594333 [==============================] - 22s 38us/step - loss: 1.6097 - p_loss: 1.4652 - v_loss: 0.1444 - p_acc: 0.3871 - val_loss: 1.6128 - val_p_loss: 1.4660 - val_v_loss: 0.1467 - val_p_acc: 0.3838\n", + "Epoch 3/10\n", + "594333/594333 [==============================] - 22s 38us/step - loss: 1.5957 - p_loss: 1.4522 - v_loss: 0.1435 - p_acc: 0.3922 - val_loss: 1.6149 - val_p_loss: 1.4664 - val_v_loss: 0.1485 - val_p_acc: 0.3915\n", + "Epoch 
4/10\n", + "594333/594333 [==============================] - 22s 38us/step - loss: 1.5872 - p_loss: 1.4444 - v_loss: 0.1428 - p_acc: 0.3948 - val_loss: 1.6454 - val_p_loss: 1.4987 - val_v_loss: 0.1468 - val_p_acc: 0.3796\n", + "Epoch 5/10\n", + "594333/594333 [==============================] - 22s 38us/step - loss: 1.5820 - p_loss: 1.4396 - v_loss: 0.1425 - p_acc: 0.3963 - val_loss: 1.5944 - val_p_loss: 1.4486 - val_v_loss: 0.1458 - val_p_acc: 0.3890\n", + "Epoch 6/10\n", + "594333/594333 [==============================] - 23s 38us/step - loss: 1.5779 - p_loss: 1.4358 - v_loss: 0.1421 - p_acc: 0.3978 - val_loss: 1.6073 - val_p_loss: 1.4606 - val_v_loss: 0.1467 - val_p_acc: 0.3887\n", + "Epoch 7/10\n", + "594333/594333 [==============================] - 23s 38us/step - loss: 1.5755 - p_loss: 1.4336 - v_loss: 0.1419 - p_acc: 0.3984 - val_loss: 1.5918 - val_p_loss: 1.4465 - val_v_loss: 0.1454 - val_p_acc: 0.3963\n", + "Epoch 8/10\n", + "594333/594333 [==============================] - 23s 38us/step - loss: 1.5729 - p_loss: 1.4312 - v_loss: 0.1417 - p_acc: 0.3997 - val_loss: 1.5935 - val_p_loss: 1.4477 - val_v_loss: 0.1458 - val_p_acc: 0.3897\n", + "Epoch 9/10\n", + "594333/594333 [==============================] - 23s 38us/step - loss: 1.5711 - p_loss: 1.4295 - v_loss: 0.1415 - p_acc: 0.4003 - val_loss: 1.6120 - val_p_loss: 1.4627 - val_v_loss: 0.1493 - val_p_acc: 0.3903\n", + "Epoch 10/10\n", + "594333/594333 [==============================] - 23s 38us/step - loss: 1.5708 - p_loss: 1.4293 - v_loss: 0.1415 - p_acc: 0.3997 - val_loss: 1.5930 - val_p_loss: 1.4465 - val_v_loss: 0.1466 - val_p_acc: 0.3916\n" + ] + } + ], + "source": [ + "history = model.fit(x_train, [p_train, v_train], batch_size=128, epochs=10, validation_data=(x_test, [p_test, v_test]))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5,1,'Value MSE')" + ] + }, + "execution_count": 9, + "metadata": {}, + 
"output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(16,6))\n", + "plt.subplot(1, 2, 1)\n", + "plt.plot(history.history['p_acc'])\n", + "plt.plot(history.history['val_p_acc'])\n", + "plt.legend(['Train', 'Validation'])\n", + "plt.title(\"Action prediction accuracy\")\n", + "plt.subplot(1, 2, 2)\n", + "plt.plot(history.history['v_loss'])\n", + "plt.plot(history.history['val_v_loss'])\n", + "plt.legend(['Train', 'Validation'])\n", + "plt.title(\"Value MSE\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "model.save('dense.h5')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/imitation/Linear model.ipynb b/imitation/Linear model.ipynb new file mode 100644 index 0000000..c7a9a0d --- /dev/null +++ b/imitation/Linear model.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/tambet/.conda/envs/pommer/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. 
In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", + " from ._conv import register_converters as _register_converters\n", + "Using TensorFlow backend.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "from keras.models import Model\n", + "from keras.layers import Input, Flatten, Dense\n", + "import tensorflow as tf\n", + "import keras.backend as K\n", + "from sklearn.metrics import explained_variance_score\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# make sure TF does not allocate all memory\n", + "config = tf.ConfigProto()\n", + "config.gpu_options.allow_growth = True\n", + "K.set_session(tf.Session(config=config))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((594333, 11, 11, 18), (594333,), (594333,))" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = np.load('simple_600K_disc0.99_cleaned.npz')\n", + "x_train = data['observations']\n", + "p_train = data['actions']\n", + "v_train = data['rewards']\n", + "x_train.shape, p_train.shape, v_train.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((95623, 11, 11, 18), (95623,), (95623,))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = np.load('valid_100K_disc0.99_cleaned.npz')\n", + "x_test = data['observations']\n", + "p_test = data['actions']\n", + "v_test = data['rewards']\n", + "x_test.shape, p_test.shape, v_test.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"__________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + "input_1 (InputLayer) (None, 11, 11, 18) 0 \n", + "__________________________________________________________________________________________________\n", + "flatten_1 (Flatten) (None, 2178) 0 input_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "p (Dense) (None, 6) 13074 flatten_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "v (Dense) (None, 1) 2179 flatten_1[0][0] \n", + "==================================================================================================\n", + "Total params: 15,253\n", + "Trainable params: 15,253\n", + "Non-trainable params: 0\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "x = Input(shape=(11,11,18))\n", + "h = Flatten()(x)\n", + "p = Dense(6, activation=\"softmax\", name='p')(h)\n", + "v = Dense(1, activation=\"tanh\", name='v')(h)\n", + "model = Model(x, [p, v])\n", + "model.summary()\n", + "model.compile(optimizer='adam', loss=['sparse_categorical_crossentropy', 'mse'], loss_weights=[1, 10], metrics={'p': 'accuracy'})" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 594333 samples, validate on 95623 samples\n", + "Epoch 1/10\n", + "594333/594333 [==============================] - 17s 29us/step - loss: 3.0279 - p_loss: 1.5010 - v_loss: 0.1527 - p_acc: 0.3666 - val_loss: 3.0192 - val_p_loss: 1.4721 - val_v_loss: 0.1547 - val_p_acc: 0.3762\n", + "Epoch 2/10\n", + "594333/594333 [==============================] - 16s 
27us/step - loss: 2.9657 - p_loss: 1.4512 - v_loss: 0.1514 - p_acc: 0.3903 - val_loss: 3.0443 - val_p_loss: 1.5404 - val_v_loss: 0.1504 - val_p_acc: 0.3683\n", + "Epoch 3/10\n", + "594333/594333 [==============================] - 16s 27us/step - loss: 2.9579 - p_loss: 1.4466 - v_loss: 0.1511 - p_acc: 0.3921 - val_loss: 3.5232 - val_p_loss: 1.5417 - val_v_loss: 0.1981 - val_p_acc: 0.3512\n", + "Epoch 4/10\n", + "594333/594333 [==============================] - 16s 27us/step - loss: 2.9666 - p_loss: 1.4457 - v_loss: 0.1521 - p_acc: 0.3922 - val_loss: 3.3368 - val_p_loss: 1.4831 - val_v_loss: 0.1854 - val_p_acc: 0.3801\n", + "Epoch 5/10\n", + "594333/594333 [==============================] - 16s 27us/step - loss: 2.9504 - p_loss: 1.4443 - v_loss: 0.1506 - p_acc: 0.3935 - val_loss: 3.0418 - val_p_loss: 1.4385 - val_v_loss: 0.1603 - val_p_acc: 0.3945\n", + "Epoch 6/10\n", + "594333/594333 [==============================] - 16s 27us/step - loss: 2.9540 - p_loss: 1.4445 - v_loss: 0.1510 - p_acc: 0.3931 - val_loss: 2.9814 - val_p_loss: 1.4604 - val_v_loss: 0.1521 - val_p_acc: 0.3867\n", + "Epoch 7/10\n", + "594333/594333 [==============================] - 17s 28us/step - loss: 2.9521 - p_loss: 1.4444 - v_loss: 0.1508 - p_acc: 0.3936 - val_loss: 3.0240 - val_p_loss: 1.4668 - val_v_loss: 0.1557 - val_p_acc: 0.3847\n", + "Epoch 8/10\n", + "594333/594333 [==============================] - 16s 27us/step - loss: 2.9516 - p_loss: 1.4428 - v_loss: 0.1509 - p_acc: 0.3936 - val_loss: 2.9462 - val_p_loss: 1.4515 - val_v_loss: 0.1495 - val_p_acc: 0.3905\n", + "Epoch 9/10\n", + "594333/594333 [==============================] - 16s 28us/step - loss: 2.9529 - p_loss: 1.4439 - v_loss: 0.1509 - p_acc: 0.3931 - val_loss: 3.0163 - val_p_loss: 1.4730 - val_v_loss: 0.1543 - val_p_acc: 0.3770\n", + "Epoch 10/10\n", + "594333/594333 [==============================] - 16s 28us/step - loss: 2.9568 - p_loss: 1.4442 - v_loss: 0.1513 - p_acc: 0.3934 - val_loss: 3.1297 - val_p_loss: 1.5114 - 
val_v_loss: 0.1618 - val_p_acc: 0.3631\n" + ] + } + ], + "source": [ + "history = model.fit(x_train, [p_train, v_train], batch_size=128, epochs=10, validation_data=(x_test, [p_test, v_test]))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5,1,'Value MSE')" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(16,6))\n", + "plt.subplot(1, 2, 1)\n", + "plt.plot(history.history['p_acc'])\n", + "plt.plot(history.history['val_p_acc'])\n", + "plt.legend(['Train', 'Validation'])\n", + "plt.title(\"Action prediction accuracy\")\n", + "plt.subplot(1, 2, 2)\n", + "plt.plot(history.history['v_loss'])\n", + "plt.plot(history.history['val_v_loss'])\n", + "plt.legend(['Train', 'Validation'])\n", + "plt.title(\"Value MSE\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Explained variance train: 0.012845516329725748\n", + "Explained variance test: -0.005572327426227375\n" + ] + } + ], + "source": [ + "_, v_train_pred = model.predict(x_train)\n", + "_, v_test_pred = model.predict(x_test)\n", + "print(\"Explained variance train:\", explained_variance_score(v_train, v_train_pred))\n", + "print(\"Explained variance test:\", explained_variance_score(v_test, v_test_pred))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "model.save('linear.h5')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (pommer)", + "language": "python", + "name": "pommer" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/imitation/collect_simple.py b/imitation/collect_simple.py new file mode 100644 index 0000000..13c68ed --- /dev/null +++ b/imitation/collect_simple.py @@ -0,0 +1,87 @@ +import pommerman +from pommerman import agents +import 
numpy as np +import argparse +from copy import deepcopy + + +def featurize(obs): + # TODO: history of n moves? + board = obs['board'] + + # convert board items into bitmaps + maps = [board == i for i in range(10)] + maps.append(obs['bomb_blast_strength']) + maps.append(obs['bomb_life']) + + # duplicate ammo, blast_strength and can_kick over entire map + maps.append(np.full(board.shape, obs['ammo'])) + maps.append(np.full(board.shape, obs['blast_strength'])) + maps.append(np.full(board.shape, obs['can_kick'])) + + # add my position as bitmap + position = np.zeros(board.shape) + position[obs['position']] = 1 + maps.append(position) + + # add teammate + if obs['teammate'] is not None: + maps.append(board == obs['teammate'].value) + else: + maps.append(np.zeros(board.shape)) + + # add enemies + enemies = [board == e.value for e in obs['enemies']] + maps.append(np.any(enemies, axis=0)) + + return np.stack(maps, axis=2) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--num_episodes', type=int, default=1000) + parser.add_argument('--render', action="store_true", default=False) + parser.add_argument('out_file') + args = parser.parse_args() + + # Create a set of agents (exactly four) + agent_list = [ + agents.SimpleAgent(), + agents.SimpleAgent(), + agents.SimpleAgent(), + agents.SimpleAgent(), + ] + + # Make the "Free-For-All" environment using the agent list + env = pommerman.make('PommeFFACompetition-v0', agent_list) + + observations = [[], [], [], []] + actions = [[], [], [], []] + rewards = [[], [], [], []] + + # Run the episodes just like OpenAI Gym + for i in range(args.num_episodes): + obs = env.reset() + done = False + reward = [0, 0, 0, 0] + t = 0 + while not done: + if args.render: + env.render() + action = env.act(obs) + new_obs, new_reward, done, info = env.step(action) + for j in range(4): + if reward[j] == 0: + observations[j].append(featurize(obs[j])) + actions[j].append(action[j]) + 
rewards[j].append(new_reward[j]) + obs = deepcopy(new_obs) + reward = deepcopy(new_reward) + t += 1 + print("Episode:", i + 1, "Max length:", t, "Rewards:", reward) + env.close() + + np.savez_compressed(args.out_file, + observations=sum(observations, []), + actions=sum(actions, []), + rewards=sum(rewards, [])) diff --git a/imitation/eval_model.py b/imitation/eval_model.py new file mode 100644 index 0000000..e997b9f --- /dev/null +++ b/imitation/eval_model.py @@ -0,0 +1,122 @@ +import pommerman +from pommerman import agents +import numpy as np +import time +from keras.models import load_model +import keras.backend as K +import tensorflow as tf +import argparse + + +def featurize(obs): + # TODO: history of n moves? + board = obs['board'] + + # convert board items into bitmaps + maps = [board == i for i in range(10)] + maps.append(obs['bomb_blast_strength']) + maps.append(obs['bomb_life']) + + # duplicate ammo, blast_strength and can_kick over entire map + maps.append(np.full(board.shape, obs['ammo'])) + maps.append(np.full(board.shape, obs['blast_strength'])) + maps.append(np.full(board.shape, obs['can_kick'])) + + # add my position as bitmap + position = np.zeros(board.shape) + position[obs['position']] = 1 + maps.append(position) + + # add teammate + if obs['teammate'] is not None: + maps.append(board == obs['teammate'].value) + else: + maps.append(np.zeros(board.shape)) + + # add enemies + enemies = [board == e.value for e in obs['enemies']] + maps.append(np.any(enemies, axis=0)) + + return np.stack(maps, axis=2) + + +class KerasAgent(agents.BaseAgent): + def __init__(self, model_file): + super().__init__() + self.model = load_model(model_file) + + def act(self, obs, action_space): + feat = featurize(obs) + probs, values = self.model.predict(feat[np.newaxis]) + action = np.argmax(probs[0]) + #print("Action:", action) + return action + + +def eval_model(agent_id, model_file, num_episodes): + # Create a set of agents (exactly four) + agent_list = [ + 
agents.SimpleAgent(), + agents.SimpleAgent(), + agents.SimpleAgent(), + ] + agent_list.insert(agent_id, KerasAgent(model_file)) + + # Make the "Free-For-All" environment using the agent list + env = pommerman.make('PommeFFACompetition-v0', agent_list) + + rewards = [] + lengths = [] + start_time = time.time() + # Run the episodes just like OpenAI Gym + for i_episode in range(num_episodes): + state = env.reset() + done = False + lens = [None] * 4 + t = 0 + while not done: + if args.render: + env.render() + actions = env.act(state) + state, reward, done, info = env.step(actions) + for j in range(4): + if lens[j] is None and reward[j] != 0: + lens[j] = t + t += 1 + rewards.append(reward) + lengths.append(lens) + print('Episode ', i_episode, "reward:", reward[agent_id], "length:", lens[agent_id]) + elapsed = time.time() - start_time + env.close() + return rewards, lengths, elapsed + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('model_file') + parser.add_argument('--num_episodes', type=int, default=400) + parser.add_argument('--render', action='store_true', default=False) + args = parser.parse_args() + + # make sure TF does not allocate all memory + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + K.set_session(tf.Session(config=config)) + + rewards0, lengths0, elapsed0 = eval_model(0, args.model_file, args.num_episodes // 4) + rewards1, lengths1, elapsed1 = eval_model(1, args.model_file, args.num_episodes // 4) + rewards2, lengths2, elapsed2 = eval_model(2, args.model_file, args.num_episodes // 4) + rewards3, lengths3, elapsed3 = eval_model(3, args.model_file, args.num_episodes // 4) + + rewards = [(r0[0], r1[1], r2[2], r3[3]) for r0, r1, r2, r3 in zip(rewards0, rewards1, rewards2, rewards3)] + lengths = [(l0[0], l1[1], l2[2], l3[3]) for l0, l1, l2, l3 in zip(lengths0, lengths1, lengths2, lengths3)] + + print("Average reward:", np.mean(rewards)) + print("Average length:", np.mean(lengths)) + + 
print("Average rewards per position:", np.mean(rewards, axis=0)) + print("Average lengths per position:", np.mean(lengths, axis=0)) + + elapsed = elapsed0 + elapsed1 + elapsed2 + elapsed3 + total_timesteps = np.sum(np.max(np.concatenate([lengths0, lengths1, lengths2, lengths3], axis=0), axis=1)) + print("Time per timestep:", elapsed / total_timesteps) diff --git a/mcts/mcts_agent.py b/mcts/mcts_agent.py new file mode 100644 index 0000000..7462b29 --- /dev/null +++ b/mcts/mcts_agent.py @@ -0,0 +1,234 @@ +import argparse +import multiprocessing +from queue import Empty +import numpy as np +import time + +import pommerman +from pommerman.agents import BaseAgent, SimpleAgent +from pommerman import constants + + +NUM_AGENTS = 4 +NUM_ACTIONS = len(constants.Action) +NUM_CHANNELS = 18 + + +def argmax_tiebreaking(Q): + # find the best action with random tie-breaking + idx = np.flatnonzero(np.isclose(Q, np.max(Q))) + assert len(idx) > 0, str(Q) + return np.random.choice(idx) + + +class MCTSNode(object): + def __init__(self, p): + # values for 6 actions + self.Q = np.zeros(NUM_ACTIONS) + self.W = np.zeros(NUM_ACTIONS) + self.N = np.zeros(NUM_ACTIONS, dtype=np.uint32) + assert p.shape == (NUM_ACTIONS,) + self.P = p + + def action(self): + U = args.mcts_c_puct * self.P * np.sqrt(np.sum(self.N)) / (1 + self.N) + return argmax_tiebreaking(self.Q + U) + + def update(self, action, reward): + self.W[action] += reward + self.N[action] += 1 + self.Q[action] = self.W[action] / self.N[action] + + def probs(self, temperature=1): + if temperature == 0: + p = np.zeros(NUM_ACTIONS) + p[argmax_tiebreaking(self.N)] = 1 + return p + else: + Nt = self.N ** (1.0 / temperature) + return Nt / np.sum(Nt) + + +class MCTSAgent(BaseAgent): + def __init__(self, agent_id=0): + super().__init__() + self.agent_id = agent_id + self.env = self.make_env() + self.reset_tree() + + def make_env(self): + agents = [] + for agent_id in range(NUM_AGENTS): + if agent_id == self.agent_id: + agents.append(self) + 
else: + agents.append(SimpleAgent()) + + return pommerman.make('PommeFFACompetition-v0', agents) + + def reset_tree(self): + self.tree = {} + + def search(self, root, num_iters, temperature=1): + # remember current game state + self.env._init_game_state = root + + for i in range(num_iters): + # restore game state to root node + obs = self.env.reset() + # serialize game state + state = str(self.env.get_json_info()) + + trace = [] + done = False + while not done: + if state in self.tree: + node = self.tree[state] + # choose actions based on Q + U + action = node.action() + trace.append((node, action)) + else: + # use uniform distribution for probs + probs = np.ones(NUM_ACTIONS) / NUM_ACTIONS + + # use current rewards for values + rewards = self.env._get_rewards() + reward = rewards[self.agent_id] + + # add new node to the tree + self.tree[state] = MCTSNode(probs) + + # stop at leaf node + break + + # ensure we are not called recursively + assert self.env.training_agent == self.agent_id + # make other agents act + actions = self.env.act(obs) + # add my action to list of actions + actions.insert(self.agent_id, action) + # step environment forward + obs, rewards, done, info = self.env.step(actions) + reward = rewards[self.agent_id] + + # fetch next state + state = str(self.env.get_json_info()) + + # update tree nodes with rollout results + for node, action in reversed(trace): + node.update(action, reward) + reward *= args.discount + + # reset env back where we were + self.env.set_json_info() + self.env._init_game_state = None + # return action probabilities + state = str(root) + return self.tree[state].probs(temperature) + + def rollout(self): + # reset search tree in the beginning of each rollout + self.reset_tree() + + # guarantees that we are not called recursively + # and episode ends when this agent dies + self.env.training_agent = self.agent_id + obs = self.env.reset() + + length = 0 + done = False + while not done: + if args.render: + self.env.render() + + root
= self.env.get_json_info() + # do Monte-Carlo tree search + pi = self.search(root, args.mcts_iters, args.temperature) + # sample action from probabilities + action = np.random.choice(NUM_ACTIONS, p=pi) + + # ensure we are not called recursively + assert self.env.training_agent == self.agent_id + # make other agents act + actions = self.env.act(obs) + # add my action to list of actions + actions.insert(self.agent_id, action) + # step environment + obs, rewards, done, info = self.env.step(actions) + assert self == self.env._agents[self.agent_id] + length += 1 + print("Agent:", self.agent_id, "Step:", length, "Actions:", [constants.Action(a).name for a in actions], "Probs:", [round(p, 2) for p in pi], "Rewards:", rewards, "Done:", done) + + reward = rewards[self.agent_id] + return length, reward, rewards + + def act(self, obs, action_space): + # TODO + assert False + + +def runner(id, num_episodes, fifo, _args): + # make args accessible to MCTSAgent + global args + args = _args + # make sure agents play at all positions + agent_id = id % NUM_AGENTS + agent = MCTSAgent(agent_id=agent_id) + + for i in range(num_episodes): + # do rollout + start_time = time.time() + length, reward, rewards = agent.rollout() + elapsed = time.time() - start_time + # add data samples to log + fifo.put((length, reward, rewards, agent_id, elapsed)) + + +def logger(fifo): + all_rewards = [] + all_lengths = [] + all_elapsed = [] + for i in range(args.num_episodes): + try: + # wait for a new trajectory + length, reward, rewards, agent_id, elapsed = fifo.get() + except Empty: + # just ignore empty fifos + continue + + print("Episode:", i, "Reward:", reward, "Length:", length, "Rewards:", rewards, "Agent:", agent_id, "Time per step:", elapsed / length) + all_rewards.append(reward) + all_lengths.append(length) + all_elapsed.append(elapsed) + + print("Average reward:", np.mean(all_rewards)) + print("Average length:", np.mean(all_lengths)) + print("Time per timestep:", np.sum(all_elapsed) / 
np.sum(all_lengths)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--render', action="store_true", default=False) + parser.add_argument('--num_episodes', type=int, default=400) + # runner params + parser.add_argument('--num_runners', type=int, default=4) + parser.add_argument('--max_steps', type=int, default=constants.MAX_STEPS) + # MCTS params + parser.add_argument('--mcts_iters', type=int, default=10) + parser.add_argument('--mcts_c_puct', type=float, default=1.0) + # RL params + parser.add_argument('--discount', type=float, default=0.99) + parser.add_argument('--temperature', type=float, default=0) + args = parser.parse_args() + + # use spawn method for starting subprocesses + ctx = multiprocessing.get_context('spawn') + + # create fifos and processes for all runners + fifo = ctx.Queue() + for i in range(args.num_runners): + process = ctx.Process(target=runner, args=(i, args.num_episodes // args.num_runners, fifo, args)) + process.start() + + # do logging in the main process + logger(fifo)