add disclaimer on conformer pytorch workload
priyakasimbeg committed Oct 11, 2023
1 parent 07c1d10 commit 2c85991
Showing 2 changed files with 17 additions and 2 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -1,3 +1,9 @@
# Change log

# TODO: algorithmic-efficiency 0.1.0
First release of the AlgoPerf benchmarking code.
Disclaimer: The Conformer PyTorch workload has a memory fragmentation issue after upgrading to
PyTorch 2.0.1. To circumvent this issue we have tuned the PyTorch memory allocation configuration,
which slows down the workload by a factor of about 2x. For submitters, this means that Conformer PyTorch
submission times will be about 2x those of an identical JAX submission.
Tracking issue: [issue/497](https://github.com/mlcommons/algorithmic-efficiency/issues/497).
11 changes: 10 additions & 1 deletion README.md
@@ -229,13 +229,22 @@ The rules for the MLCommons Algorithmic Efficiency benchmark can be found in the
If you are interested in contributing to the work of the working group, feel free to [join the weekly meetings](https://mlcommons.org/en/groups/research-algorithms/) or open issues. See our [CONTRIBUTING.md](CONTRIBUTING.md) for MLCommons contributing guidelines and setup and workflow instructions.


# Note on shared data pipelines between JAX and PyTorch
# Disclaimers

# Shared data pipelines between JAX and PyTorch

The JAX and PyTorch versions of the Criteo, FastMRI, Librispeech, OGBG, and WMT workloads use the same TensorFlow input pipelines. Due to differences in how JAX and PyTorch distribute computations across devices, the PyTorch versions of these workloads incur additional overhead.

Since we use PyTorch's [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implementation, there is one Python process for each device. Depending on the hardware and the settings of the cluster, running a TensorFlow input pipeline in each Python process can lead to errors, since too many threads are created in each process. See [this PR thread](https://github.com/mlcommons/algorithmic-efficiency/pull/85) for more details.
While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example.
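
To illustrate the strategy, here is a minimal sketch (not the repository's actual implementation; the helper name and batch handling are assumptions) of broadcasting a batch produced on rank 0 to all other ranks:

```python
# Minimal sketch of the rank-0 broadcast strategy described above.
# Assumes torch.distributed has already been initialized (e.g. via
# init_process_group) and that every rank knows the batch shape and dtype.
import torch
import torch.distributed as dist


def get_broadcast_batch(data_iter, batch_shape, device, dtype=torch.float32):
  """Hypothetical helper: rank 0 reads from the input pipeline, others receive."""
  if dist.get_rank() == 0:
    # Only rank 0 runs the (TensorFlow) input pipeline.
    batch = next(data_iter).to(device=device, dtype=dtype)
  else:
    # Other ranks allocate an empty buffer that the broadcast will fill.
    batch = torch.empty(batch_shape, device=device, dtype=dtype)
  # torch.distributed.broadcast copies rank 0's tensor to every other rank.
  dist.broadcast(batch, src=0)
  return batch
```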

# Conformer workload 2x slower in PyTorch vs JAX
The Conformer PyTorch workload has a memory fragmentation issue after upgrading to
PyTorch 2.0.1, which led to out-of-memory errors. To circumvent this issue we have tuned the PyTorch
memory allocation configuration, which slows down the workload by a factor of roughly 2x. For submitters, this
means that Conformer PyTorch submission times will be roughly 2x those of an identical JAX submission.
Tracking issue: [issue/497](https://github.com/mlcommons/algorithmic-efficiency/issues/497).
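
The allocator tuning mentioned above is done through PyTorch's `PYTORCH_CUDA_ALLOC_CONF` environment variable. The exact setting used by the benchmark is defined in the repository; the option and value below are purely illustrative:

```python
# Illustrative only: the concrete allocator options used by the benchmark live
# in the repository. PYTORCH_CUDA_ALLOC_CONF must be set before CUDA is
# initialized for the caching allocator to pick it up.
import os

os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:256')

import torch  # imported after setting the environment variable
```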

# FAQS
## Setup
### Why do I get a warning that GPU is not found?
