Skip to content

Commit

Permalink
Merge pull request #30 from instadeepai/feat/mixed_experience_replay
Browse files Browse the repository at this point in the history
feat: Mixed Experience Replay 🤝
  • Loading branch information
EdanToledo authored Sep 20, 2024
2 parents 3cadd21 + 989976c commit e0199d7
Show file tree
Hide file tree
Showing 6 changed files with 809 additions and 1 deletion.
305 changes: 305 additions & 0 deletions examples/mixer_demonstration.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import flashbax as fbx\n",
"import jax.numpy as jnp\n",
"from jax.tree_util import tree_map\n",
"import jax\n",
"\n",
"key = jax.random.PRNGKey(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TrajectoryBufferSample(experience={'acts': (4, 5, 3), 'obs': (4, 5, 2)})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create our first buffer, with a sample batch size of 4\n",
"buffer_a = fbx.make_trajectory_buffer(\n",
" add_batch_size=1,\n",
" max_length_time_axis=1000,\n",
" min_length_time_axis=5,\n",
" sample_sequence_length=5,\n",
" period=1,\n",
" sample_batch_size=4,\n",
")\n",
"\n",
"timestep = {\n",
" \"obs\": jnp.ones((2)),\n",
" \"acts\": jnp.ones(3),\n",
"}\n",
"\n",
"state_a = buffer_a.init(\n",
" timestep,\n",
")\n",
"for i in range(100):\n",
" # Fill with POSITIVE values\n",
" state_a = jax.jit(buffer_a.add, donate_argnums=0)(\n",
" state_a,\n",
" tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),\n",
" )\n",
"\n",
"sample_a = buffer_a.sample(state_a, key)\n",
"tree_map(lambda x: x.shape, sample_a)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TrajectoryBufferSample(experience={'acts': (16, 5, 3), 'obs': (16, 5, 2)})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create our second buffer, with a sample batch size of 16\n",
"buffer_b = fbx.make_trajectory_buffer(\n",
" add_batch_size=1,\n",
" max_length_time_axis=1000,\n",
" min_length_time_axis=5,\n",
" sample_sequence_length=5,\n",
" period=1,\n",
" sample_batch_size=16,\n",
")\n",
"\n",
"timestep = {\n",
" \"obs\": jnp.ones((2)),\n",
" \"acts\": jnp.ones(3),\n",
"}\n",
"\n",
"state_b = buffer_b.init(\n",
" timestep,\n",
")\n",
"for i in range(100):\n",
" # Fill with NEGATIVE values\n",
" state_b = jax.jit(buffer_b.add, donate_argnums=0)(\n",
" state_b,\n",
" tree_map(lambda x, _i=i: (- x * _i)[None, None, ...], timestep),\n",
" )\n",
"\n",
"sample_b = buffer_b.sample(state_b, key)\n",
"tree_map(lambda x: x.shape, sample_b)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Make the mixer, with a ratio of 1:3 from buffer_a:buffer_b\n",
"mixer = fbx.make_mixer(\n",
" buffers=[buffer_a, buffer_b],\n",
" sample_batch_size=8,\n",
" proportions=[1,3],\n",
")\n",
"\n",
"# jittable sampling!\n",
"mixer_sample = jax.jit(mixer.sample)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TrajectoryBufferSample(experience={'acts': (8, 5, 3), 'obs': (8, 5, 2)})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sample from the mixer, using the usual flashbax API\n",
"joint_sample = mixer_sample(\n",
" [state_a, state_b],\n",
" key,\n",
")\n",
"\n",
"# Notice the resulting shape\n",
"tree_map(lambda x: x.shape, joint_sample)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TrajectoryBufferSample(experience={'acts': Array([[[90., 90., 90.],\n",
" [91., 91., 91.],\n",
" [92., 92., 92.],\n",
" [93., 93., 93.],\n",
" [94., 94., 94.]],\n",
"\n",
" [[56., 56., 56.],\n",
" [57., 57., 57.],\n",
" [58., 58., 58.],\n",
" [59., 59., 59.],\n",
" [60., 60., 60.]]], dtype=float32), 'obs': Array([[[90., 90.],\n",
" [91., 91.],\n",
" [92., 92.],\n",
" [93., 93.],\n",
" [94., 94.]],\n",
"\n",
" [[56., 56.],\n",
" [57., 57.],\n",
" [58., 58.],\n",
" [59., 59.],\n",
" [60., 60.]]], dtype=float32)})"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Notice how the first 1/4 * 8 = 2 batches are from buffer_a (POSITIVE VALUES)\n",
"tree_map(lambda x: x[0:2], joint_sample)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TrajectoryBufferSample(experience={'acts': Array([[[-34., -34., -34.],\n",
" [-35., -35., -35.],\n",
" [-36., -36., -36.],\n",
" [-37., -37., -37.],\n",
" [-38., -38., -38.]],\n",
"\n",
" [[-88., -88., -88.],\n",
" [-89., -89., -89.],\n",
" [-90., -90., -90.],\n",
" [-91., -91., -91.],\n",
" [-92., -92., -92.]],\n",
"\n",
" [[-30., -30., -30.],\n",
" [-31., -31., -31.],\n",
" [-32., -32., -32.],\n",
" [-33., -33., -33.],\n",
" [-34., -34., -34.]],\n",
"\n",
" [[-11., -11., -11.],\n",
" [-12., -12., -12.],\n",
" [-13., -13., -13.],\n",
" [-14., -14., -14.],\n",
" [-15., -15., -15.]],\n",
"\n",
" [[-78., -78., -78.],\n",
" [-79., -79., -79.],\n",
" [-80., -80., -80.],\n",
" [-81., -81., -81.],\n",
" [-82., -82., -82.]],\n",
"\n",
" [[-15., -15., -15.],\n",
" [-16., -16., -16.],\n",
" [-17., -17., -17.],\n",
" [-18., -18., -18.],\n",
" [-19., -19., -19.]]], dtype=float32), 'obs': Array([[[-34., -34.],\n",
" [-35., -35.],\n",
" [-36., -36.],\n",
" [-37., -37.],\n",
" [-38., -38.]],\n",
"\n",
" [[-88., -88.],\n",
" [-89., -89.],\n",
" [-90., -90.],\n",
" [-91., -91.],\n",
" [-92., -92.]],\n",
"\n",
" [[-30., -30.],\n",
" [-31., -31.],\n",
" [-32., -32.],\n",
" [-33., -33.],\n",
" [-34., -34.]],\n",
"\n",
" [[-11., -11.],\n",
" [-12., -12.],\n",
" [-13., -13.],\n",
" [-14., -14.],\n",
" [-15., -15.]],\n",
"\n",
" [[-78., -78.],\n",
" [-79., -79.],\n",
" [-80., -80.],\n",
" [-81., -81.],\n",
" [-82., -82.]],\n",
"\n",
" [[-15., -15.],\n",
" [-16., -16.],\n",
" [-17., -17.],\n",
" [-18., -18.],\n",
" [-19., -19.]]], dtype=float32)})"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# and how the second 3/4 * 8 = 6 batches are from buffer_b (NEGATIVE VALUES)\n",
"tree_map(lambda x: x[2:], joint_sample)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "flashbax",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 2 additions & 0 deletions flashbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
item_buffer,
make_flat_buffer,
make_item_buffer,
make_mixer,
make_prioritised_flat_buffer,
make_prioritised_item_buffer,
make_prioritised_trajectory_buffer,
make_trajectory_buffer,
make_trajectory_queue,
mixer,
prioritised_flat_buffer,
prioritised_item_buffer,
prioritised_trajectory_buffer,
Expand Down
1 change: 1 addition & 0 deletions flashbax/buffers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from flashbax.buffers.flat_buffer import make_flat_buffer
from flashbax.buffers.item_buffer import make_item_buffer
from flashbax.buffers.mixer import make_mixer
from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer
from flashbax.buffers.prioritised_item_buffer import make_prioritised_item_buffer
from flashbax.buffers.prioritised_trajectory_buffer import (
Expand Down
Loading

0 comments on commit e0199d7

Please sign in to comment.