Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flag for setting pytorch max_split_size_mb #559

Merged
merged 6 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ The JAX and PyTorch versions of the Criteo, FastMRI, Librispeech, OGBG, and WMT
Since we use PyTorch's [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implementation, there is one Python process for each device. Depending on the hardware and the settings of the cluster, running a TensorFlow input pipeline in each Python process can lead to errors, since too many threads are created in each process. See [this PR thread](https://github.com/mlcommons/algorithmic-efficiency/pull/85) for more details.
While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example.

## Pytorch Conformer CUDA OOM

The conformer pytorch workload may run out of memory in current state. Please set the `submission_runner.py` flag `reduce_pytorch_max_split_size` to `True` as a temporary workaround if you encounter this issue. This will set 'max_split_size_mb:256'. Note that this will adversely impact the performance of the submission on this workload. See [tracking issue](https://github.com/mlcommons/algorithmic-efficiency/issues/497).


# FAQS

## Setup and Platform
Expand Down
8 changes: 4 additions & 4 deletions baselines/shampoo/jax/distributed_shampoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh(
alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
if padding_start is not None:
ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
matrix.dtype)
ix = (jnp.arange(matrix_size, dtype=jnp.int32)
< padding_start).astype(matrix.dtype)
matrix *= ix[jnp.newaxis, :]
matrix *= ix[:, jnp.newaxis]
identity *= ix
Expand Down Expand Up @@ -1923,8 +1923,8 @@ def _internal_inverse_pth_root_all():
errors = metrics.inverse_pth_root_errors
errors = errors.reshape((-1, 1, 1))
predicate = jnp.logical_or(
jnp.isnan(errors),
errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
jnp.isnan(errors), errors
>= inverse_failure_threshold).astype(new_preconditioners.dtype)
# TODO(rohananil): Check for numerical instabilities.
new_conditional_preconditioners = (
predicate * global_stats.preconditioners +
Expand Down
6 changes: 6 additions & 0 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@
None,
'Value of rng seed. If None, a random seed will'
'be generated from hardware.')
flags.DEFINE_boolean('set_pytorch_max_split_size',
None,
priyakasimbeg marked this conversation as resolved.
Show resolved Hide resolved
'If true, set pytorch max_split_size_mb to 256')
FLAGS = flags.FLAGS
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()

Expand Down Expand Up @@ -602,6 +605,9 @@ def main(_):
if FLAGS.workload == 'librispeech_conformer':
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'

if FLAGS.set_pytorch_max_split_size is True:
priyakasimbeg marked this conversation as resolved.
Show resolved Hide resolved
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'

# Extend path according to framework.
workload_metadata['workload_path'] = os.path.join(
BASE_WORKLOADS_DIR,
Expand Down
Loading