Commit 692d761 (0 parents)
Upload initial version of the public repository
Showing 16 changed files with 504,659 additions and 0 deletions.
.gitignore
@@ -0,0 +1,2 @@
*.pyc
.ipynb_checkpoints
Large diffs are not rendered by default.
NNs.ipynb
@@ -0,0 +1,266 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4438df4a",
   "metadata": {},
   "source": [
    "# TITLE - NNs\n",
    "This notebook gives small examples of how to work with the NN models introduced in this paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2870fb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# add the src folder to the python path to import the classes there\n",
    "import sys\n",
    "sys.path.append(\"./src/\")\n",
    "from cognitive_prior_network import CognitivePriorNetwork\n",
    "from context_dependant_network import ContextDependantNetwork"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fd2e055",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_features = [\"Ha\", \"pHa\", \"La\", \"Hb\", \"pHb\", \"Lb\", \"LotNumB\",\n",
    "                 \"LotShapeB\", \"Corr\", \"Amb\", \"Block\", \"Feedback\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3cd0211",
   "metadata": {},
   "source": [
    "## Using pretrained models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24d16eab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize the class, then load the weights from a location\n",
    "nn_cpc15 = CognitivePriorNetwork()\n",
    "nn_cpc15.load(\"models/cpc_bourgin_prior\")\n",
    "\n",
    "# load a dataset to use the model on\n",
    "choices_df = pd.read_csv(\"data/choices13k.csv\")\n",
    "\n",
    "# prediction works via predict() once the right features are extracted\n",
    "nn_cpc15_predictions = nn_cpc15.predict(choices_df[base_features])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93ad9634",
   "metadata": {},
   "outputs": [],
   "source": [
    "# these predictions match the precomputed values stored in the dataset\n",
    "np.max(np.abs(choices_df.cpc15_cog_prior_pred - nn_cpc15_predictions.flatten()))"
   ]
  },
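  {
   "cell_type": "markdown",
   "id": "b1c2d3e4",
   "metadata": {},
   "source": [
    "The `ContextDependantNetwork` imported above is not exercised in this walkthrough. Assuming it exposes the same `load`/`predict` interface and feature set as `CognitivePriorNetwork` (an assumption, not verified here), usage would look like the sketch below; the model path is a placeholder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2d3e4f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical sketch: assumes ContextDependantNetwork mirrors the\n",
    "# CognitivePriorNetwork interface; the model path is a placeholder\n",
    "nn_context = ContextDependantNetwork()\n",
    "nn_context.load(\"models/...\")\n",
    "nn_context_predictions = nn_context.predict(choices_df[base_features])"
   ]
  },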
  {
   "cell_type": "markdown",
   "id": "f42a9504",
   "metadata": {},
   "source": [
    "## Training new models\n",
    "We demonstrate the training code once, for a cognitive prior network on choices13k, since more of the models are able to fit that dataset. To fit models on CPC15, more patience and more pretraining is what helped for us; models without pretraining did not work on CPC15 for us at all.\n",
    "\n",
    "### Training imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "670edc84",
   "metadata": {},
   "outputs": [],
   "source": [
    "from hyperopt import pyll, hp, STATUS_OK, fmin, tpe, Trials\n",
    "import pickle\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8f0c35e",
   "metadata": {},
   "source": [
    "### Data loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71bd4846",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_xy(dataframe):\n",
    "    y = dataframe[\"Rate\"].values.astype(np.float32)\n",
    "    X = dataframe[base_features].values.astype(np.float32)\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9fdf991",
   "metadata": {},
   "outputs": [],
   "source": [
    "synth15_df = pd.read_csv(\"data/synth15.csv\")\n",
    "X_synth15, y_synth15 = split_xy(synth15_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd84bc95",
   "metadata": {},
   "outputs": [],
   "source": [
    "cpc15_df = pd.read_csv(\"data/cpc15.csv\", index_col=0)\n",
    "X_cpc15_train, y_cpc15_train = split_xy(cpc15_df.iloc[:450])\n",
    "X_cpc15_test, y_cpc15_test = split_xy(cpc15_df.iloc[450:])"
   ]
  },
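  {
   "cell_type": "markdown",
   "id": "5a6b7c8d",
   "metadata": {},
   "source": [
    "A quick sanity check on the resulting splits (shapes only):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b7c8d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the train split should hold the first 450 CPC15 problems, the test split the rest\n",
    "X_synth15.shape, X_cpc15_train.shape, X_cpc15_test.shape"
   ]
  },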
  {
   "cell_type": "markdown",
   "id": "c9736e2c",
   "metadata": {},
   "source": [
    "### Pretraining - Hyperparameter Optimization\n",
    "Pretrain multiple models in a principled way.\n",
    "\n",
    "Every set of parameters gets evaluated with 5 different random seeds. All models and their corresponding validation loss histories are saved under `../models/wide_pretraining`, so make sure you have created this folder on your system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13f3bf59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# we optimize pretraining over the batch size during training, as well as over\n",
    "# the two SET parameters deciding the connectivity and the rate of change of the network\n",
    "hyperparams = {\n",
    "    'batch_size': hp.uniformint('batch_size', 300, 1500),\n",
    "    'epsilon': hp.uniformint('eps', 10, 100),\n",
    "    'zeta': hp.uniform('zeta', 0.2, 0.8)\n",
    "}\n",
    "\n",
    "def f(space):\n",
    "    l = np.inf\n",
    "    for i in range(5):\n",
    "        cognitive_model = CognitivePriorNetwork(input_shape=12, batch_size=space['batch_size'],\n",
    "                                                epsilon=space['epsilon'], zeta=space['zeta'])\n",
    "        cognitive_model.X_test = X_cpc15_train\n",
    "        cognitive_model.y_test = y_cpc15_train\n",
    "        cognitive_model.fit(X_synth15, y_synth15, verbose=0, epochs=300, patience=300)\n",
    "        cognitive_model.save('../models/wide_pretraining/bs_%d_eps_%d_zeta_%.4f_iter_%d'%\n",
    "                             (space['batch_size'], space['epsilon'], space['zeta'], i))\n",
    "\n",
    "        loss = np.min(cognitive_model.loss_per_epoch)\n",
    "        if loss < l:\n",
    "            l = loss\n",
    "    print('done searching combination: batch_size: %d, epsilon: %d, zeta %.4f'%\n",
    "          (space['batch_size'], space['epsilon'], space['zeta']))\n",
    "    return {'loss': l, 'status': STATUS_OK}\n",
    "\n",
    "# these optimizations can be interrupted and resumed later:\n",
    "# pickle the trials object if the current run was interrupted (see below)\n",
    "if os.path.isfile(\"models/cpc15_prior_training.hyperopt\"):\n",
    "    trials = pickle.load(open(\"models/cpc15_prior_training.hyperopt\", \"rb\"))\n",
    "else:\n",
    "    trials = Trials()\n",
    "\n",
    "best = fmin(\n",
    "    fn=f,               # \"loss\" function to minimize\n",
    "    space=hyperparams,  # hyperparameter space\n",
    "    algo=tpe.suggest,   # Tree-structured Parzen Estimator (TPE)\n",
    "    max_evals=50,       # number of trials to perform\n",
    "    trials=trials\n",
    ")"
   ]
  },
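  {
   "cell_type": "markdown",
   "id": "7e1a2b3c",
   "metadata": {},
   "source": [
    "To be able to resume an interrupted search, persist the `Trials` object; the load branch in the cell above picks this file up on the next run. A minimal sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f2b3c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the current hyperopt state; rerunning the search cell will resume from it\n",
    "with open(\"models/cpc15_prior_training.hyperopt\", \"wb\") as handle:\n",
    "    pickle.dump(trials, handle)"
   ]
  },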
  {
   "cell_type": "markdown",
   "id": "4275837f",
   "metadata": {},
   "source": [
    "### Finetuning\n",
    "Finetuning, on the other hand, is much simpler. We keep the same batch size, lower the learning rate, and no longer have to fit the SET parameters. The only thing we changed between CPC15 and choices13k is the number of epochs. CPC15 needs many more, and somewhere in that training process the random SET procedure can help you make the last jump to the loss reported in our paper and in Bourgin et al.'s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9ff2296",
   "metadata": {},
   "outputs": [],
   "source": [
    "# insert the name of the pretrained model as well as its batch size\n",
    "pre_trained = CognitivePriorNetwork(input_shape=12, batch_size=X)\n",
    "pre_trained.load(\"../models/wide_pretraining/...\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f65ee06",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "pre_trained.X_test = X_cpc15_test\n",
    "pre_trained.y_test = y_cpc15_test\n",
    "pre_trained.fit(X_cpc15_train,\n",
    "                y_cpc15_train,\n",
    "                learning_rate=1e-6,\n",
    "                verbose=1,\n",
    "                epochs=3000,  # for choices13k, you only need 100 epochs\n",
    "                patience=3000)"
   ]
  },
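  {
   "cell_type": "markdown",
   "id": "9a3c4d5e",
   "metadata": {},
   "source": [
    "Once finetuning has converged, the model can be saved and reused just like the pretrained networks above. A minimal sketch; the file name is only a placeholder:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b4d5e6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# placeholder path: pick any name under models/\n",
    "pre_trained.save(\"models/my_finetuned_cpc15\")\n",
    "\n",
    "# the finetuned network predicts exactly like the pretrained one above\n",
    "finetuned_predictions = pre_trained.predict(cpc15_df.iloc[450:][base_features])"
   ]
  }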
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
README.md
@@ -0,0 +1,26 @@
# Title
This is the official repository containing code and data for the paper (title), currently under review.

![Explanatory Figure](figure1.pdf)

## Data
All three datasets used in this study ([CPC15](https://economics.agri.huji.ac.il/crc2015/raw-data), [CPC18](https://cpc-18.com/data/) and [choices13k](https://github.com/jcpeterson/choices13k)) were already publicly available (click the names for more information).
Aggregated versions (no longer containing any individual-level data), with additional features and some utility columns, are stored in the [data folder](./data/).
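Each CSV can be loaded directly with pandas, as also done in the models notebook:
```
import pandas as pd

choices13k_df = pd.read_csv("data/choices13k.csv")
```
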
## Analysis
The [Analysis notebook](./Analysis.ipynb) shows how to reproduce many of the plots and analyses from the paper.

## Models
Pretrained models (under the names shown in Table 1 of the paper) are stored in the [models folder](./models).
The [models notebook](./NNs.ipynb) gives simple examples of how to load, save and train the NN models discussed in the paper.
The underlying source code is in the [src folder](./src).

## Usage
To use the notebooks and reproduce our results, install the conda environment using
```
conda env create -f environment.yml
conda activate DecisionMaking
```
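Assuming Jupyter is included in the environment (not verified here), the notebooks can then be opened from the repository root with
```
jupyter notebook NNs.ipynb
```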

## Citation
The paper is currently under review.