Commit feaf39f (1 parent: 5d5bd56)

Former-commit-id: b0d46adcb58c89d7589a1bae9227554fb37b3211 [formerly c3b968dcd4fa403ea186f47beeee1aae34ecb7a2]
Former-commit-id: 7122763bafc0fc3f7972ad71af64a56996904a5f

Showing 14 changed files with 575 additions and 34 deletions.
@@ -1,5 +1,5 @@
-We benchmark three baselines for policy gradient method in several different perspectives
-1. REINFORCE
-2. Actor-Critic/Vanilla Policy Gradient
-3. Advantage Actor-Critic (A2C)
+This example includes the implementations of the following policy gradient algorithms:
+
+- [REINFORCE](reinforce)
+- [Vanilla Policy Gradient (VPG)](vpg)
+- [Advantage Actor-Critic (A2C)](a2c)
@@ -0,0 +1,17 @@
# Advantage Actor-Critic (A2C)

This is an implementation of the [A2C](https://blog.openai.com/baselines-acktr-a2c/) algorithm.

# Usage

Run the following command to start parallelized training:

```bash
python main.py
```

You can modify [experiment.py](./experiment.py) to quickly set up different configurations.

# Results

<img src='data/result.png' width='75%'>
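The configuration keys this example exposes can be seen in the analysis notebook later in this commit (env.id, algo.lr, algo.gamma, train.N, train.T, and so on). As a rough, illustrative sketch of what such a configuration covers, expressed as a plain Python dict rather than the repository's actual experiment.py format (the helper name and the per-key comments are assumptions, not taken from the repo):

```python
# Illustrative only: mirrors the configuration columns shown in the notebook
# output further down in this commit; the real experiment.py may build these
# settings with lagom's Configurator rather than a plain dict.
def make_a2c_config():
    return {
        'cuda': True,                      # train on GPU
        'env.id': 'HalfCheetah-v2',        # Gym environment
        'env.standardize': True,           # standardize observations/rewards (assumed meaning)
        'network.hidden_sizes': [64, 64],  # hidden layers of the policy/value networks
        'algo.lr': 1e-3,                   # learning rate
        'algo.use_lr_scheduler': True,     # anneal the learning rate during training
        'algo.gamma': 0.99,                # discount factor
        'agent.standardize_adv': True,     # standardize advantage estimates (assumed meaning)
        'train.timestep': 1_000_000,       # total environment steps
        'train.N': 16,                     # rollout segments per iteration (assumed meaning)
        'train.T': 5,                      # steps per rollout segment (assumed meaning)
        'eval.N': 10,                      # evaluation episodes
        'log.dir': 'logs',                 # where configs.pkl and eval_logs.pkl are written
    }
```

This is only meant to show which knobs the example exposes; consult experiment.py in the repository for the real interface.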
@@ -0,0 +1,232 @@

A new Jupyter notebook (Python 3 kernel, Python 3.6.6, nbformat 4) that loads the saved experiment logs and plots the evaluation returns. Its code cells, in order:

```python
from pathlib import Path
from lagom.experiment import Configurator

from lagom import pickle_load

from lagom.core.plotter import CurvePlot
```

This import emits a warning from `lagom/core/plotter/__init__.py`: `UserWarning: ImageViewer failed to import due to pyglet.`

```python
log_folder = Path('logs')

list_config = pickle_load(log_folder/'configs.pkl')
configs = Configurator.to_dataframe(list_config)
configs
```

The configuration DataFrame has 1 row × 25 columns; the settings shown in the output (those collapsed behind "..." are omitted) are:

Setting | Value
---|---
ID | 0
cuda | True
env.id | HalfCheetah-v2
env.standardize | True
network.hidden_sizes | [64, 64]
algo.lr | 0.001
algo.use_lr_scheduler | True
algo.gamma | 0.99
agent.standardize_Q | False
agent.standardize_adv | True
agent.constant_std | None
agent.std_state_dependent | False
agent.init_std | 0.5
train.timestep | 1000000.0
train.N | 16
train.T | 5
eval.N | 10
log.record_interval | 100
log.print_interval | 1000
log.dir | logs

```python
def load_results(log_folder, ID, f):
    p = Path(log_folder)/str(ID)

    list_result = []
    for sub in p.iterdir():
        if sub.is_dir() and (sub/f).exists():
            list_result.append(pickle_load(sub/f))

    return list_result


def get_returns(list_result):
    returns = []
    for result in list_result:
        #x_values = [i['evaluation_iteration'][0] for i in result]
        x_values = [i['accumulated_trained_timesteps'][0] for i in result]
        y_values = [i['average_return'][0] for i in result]
        returns.append([x_values, y_values])

    return returns
```

```python
ID = 0
env_id = configs.loc[configs['ID'] == ID]['env.id'].values[0]
```

```python
list_result = load_results('logs', ID, 'eval_logs.pkl')
returns = get_returns(list_result)
x_values, y_values = zip(*returns)
```

```python
plot = CurvePlot()
plot.add('A2C', y_values, xvalues=x_values)
ax = plot(title=f'A2C on {env_id}', 
          xlabel='Iteration', 
          ylabel='Mean Episode Reward', 
          num_tick=6, 
          xscale_magnitude=None)
```

```python
ax.figure.savefig('data/result.png')
```
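One detail worth noting in the notebook above: get_returns produces a list of [x_values, y_values] pairs, one per run, and zip(*returns) transposes that into a tuple of x-series and a tuple of y-series, which is the shape CurvePlot.add appears to expect here. A tiny self-contained illustration (the numbers are made up):

```python
# Each run contributes one [x_values, y_values] pair, e.g. two runs:
returns = [
    [[1000, 2000, 3000], [-50.0, -20.0, 10.0]],   # run 1: timesteps, average returns
    [[1000, 2000, 3000], [-45.0, -15.0, 20.0]],   # run 2
]

# zip(*returns) transposes the list of pairs into (all x-series, all y-series):
x_values, y_values = zip(*returns)
print(x_values)  # ([1000, 2000, 3000], [1000, 2000, 3000])
print(y_values)  # ([-50.0, -20.0, 10.0], [-45.0, -15.0, 20.0])
```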
@@ -0,0 +1,17 @@
# REINFORCE

This is an implementation of the [REINFORCE](https://link.springer.com/article/10.1007/BF00992696) algorithm.

# Usage

Run the following command to start parallelized training:

```bash
python main.py
```

You can modify [experiment.py](./experiment.py) to quickly set up different configurations.

# Results

<img src='data/result.png' width='75%'>
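For background on what this example trains (this is not the repository's code, just a generic sketch), REINFORCE follows the gradient of log-probabilities weighted by Monte Carlo returns. A minimal PyTorch-style sketch, where log_probs and rewards are assumed to have been collected from one episode:

```python
import torch

# Generic REINFORCE loss sketch -- illustrative, not this repository's implementation.
# log_probs: list of torch scalars log pi(a_t | s_t) for one episode
# rewards:   list of floats r_t for the same episode
def reinforce_loss(log_probs, rewards, gamma=0.99):
    # Discounted Monte Carlo return G_t for every step, computed backwards.
    returns = []
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)
    returns = torch.tensor(returns)
    # Standardizing the returns is a common variance-reduction trick.
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    # Minimize the negative of the policy-gradient objective E[log pi * G].
    return -(torch.stack(log_probs) * returns).sum()
```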