Skip to content

Commit

Permalink
distributed tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Mar 3, 2024
1 parent 8f14e77 commit 768881a
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions docs/tutorials/distributed_computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,8 @@ Distributed computation

We present here how to perform computation on multiple devices.

Imagine to have at your disposal 4 GPUs and you want to distribute the workload on them.
There are two ways to do so:

* Create 4 simulators, specifying a different device for each one
* Use the `JAX pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`_ function to wrap the functions you need.

If memory is not an issue, the second method is the easiest one. In fact, you simply need divide your population in groups (i.e. divide the first axis) and distribute over the groups.
Consider a scenario where you have access to four GPUs and aim to distribute the workload effectively among them.
To achieve this, we employ the `JAX pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`_ function, which allows seamless distribution of functions across multiple accelerators.

.. code-block:: python
Expand All @@ -21,7 +16,7 @@ If memory is not an issue, the second method is the easiest one. In fact, you si
# load 200 individuals
population = simulator.load_population(genome)[:200]
# divide them in 4 groups
population = population.reshape(4, 50, *population.shape[1:])
population = population.reshape(4, -1, *population.shape[1:])
# prepare a parallelized function over groups
pmap_dh = jax.pmap(
Expand All @@ -35,7 +30,6 @@ If memory is not an issue, the second method is the easiest one. In fact, you si
dh_pop = dh_pop.reshape(-1, *dh_pop.shape[2:])
If you want to perform random crosses or full diallel, grouping the population will change the semantics (the random crosses or the full diallel will be performed by group independently).
In this case, you should use the function ``cross`` after generating the proper array of parents.
For example, to perform random crosses:
Expand All @@ -51,7 +45,7 @@ For example, to perform random crosses:
random_indices = np.random.random_integers(0, len(population) - 1, size=(200, 2))
parents = population[random_indices]
parents = parents.reshape(4, 50, *parents.shape[1:])
parents = parents.reshape(4, -1, *parents.shape[1:])
pmap_cross = jax.pmap(simulator.cross,)
new_pop = pmap_cross(parents)
new_pop = new_pop.reshape(-1, *new_pop.shape[2:])
Expand Down

0 comments on commit 768881a

Please sign in to comment.