From 119f8d7b784690f1d14425f84fac8ce92abc14b0 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Tue, 31 Oct 2023 23:39:54 +0000
Subject: [PATCH 1/6] add flag for setting max split size

---
 README.md            | 5 +++++
 submission_runner.py | 8 ++++++++
 2 files changed, 13 insertions(+)

diff --git a/README.md b/README.md
index 6ffbab6f7..de8ea060d 100644
--- a/README.md
+++ b/README.md
@@ -126,8 +126,13 @@ To use the Docker container as an interactive virtual environment, you can run a
   -v $HOME/algorithmic-efficiency:/algorithmic-efficiency \
   --gpus all \
   --ipc=host \
+<<<<<<< HEAD
+  <docker_image_name> \
+  -keep_container_alive true
+=======
   <docker_image_name> \
   --keep_container_alive true
+>>>>>>> ba5c6f6175a0ce12f23a7f035613d9d1edc0b74a
 ```
 Note: You may have to use double quotes around `algorithmic-efficiency` [path] in the mounting `-v` flag. If the above command fails try replacing the following line:
 ```bash
diff --git a/submission_runner.py b/submission_runner.py
index 656599a42..6d4cc98e2 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -149,6 +149,11 @@
                      None,
                      'Value of rng seed. If None, a random seed will'
                      'be generated from hardware.')
+flags.DEFINE_boolean(
+    'set_pytorch_max_split_size',
+    None,
+    'If true, set pytorch max_split_size_mb to 256'
+)
 FLAGS = flags.FLAGS
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
 
@@ -601,6 +606,9 @@ def main(_):
   # Prevent OOM on librispeech conformer.
   if FLAGS.workload == 'librispeech_conformer':
     os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'
+
+  if FLAGS.set_pytorch_max_split_size is True:
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
 
   # Extend path according to framework.
   workload_metadata['workload_path'] = os.path.join(

From de45bf7fe4d90af16abc58bb685103722fbec44d Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Tue, 31 Oct 2023 23:46:49 +0000
Subject: [PATCH 2/6] add documentation

---
 README.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/README.md b/README.md
index de8ea060d..289a93dec 100644
--- a/README.md
+++ b/README.md
@@ -246,6 +246,11 @@ The JAX and PyTorch versions of the Criteo, FastMRI, Librispeech, OGBG, and WMT
 Since we use PyTorch's [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implementation, there is one Python process for each device. Depending on the hardware and the settings of the cluster, running a TensorFlow input pipeline in each Python process can lead to errors, since too many threads are created in each process. See [this PR thread](https://github.com/mlcommons/algorithmic-efficiency/pull/85) for more details.
 While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example.
 
+## PyTorch Conformer CUDA OOM
+
+The PyTorch Conformer workload may currently run out of memory. If you encounter this issue, set the `submission_runner.py` flag `set_pytorch_max_split_size` to `True` as a temporary workaround. This will set `PYTORCH_CUDA_ALLOC_CONF` to `max_split_size_mb:256`. Note that this will adversely impact the performance of the submission on this workload. See the [tracking issue](https://github.com/mlcommons/algorithmic-efficiency/issues/497).
+
+
 # FAQS
 
 ## Setup and Platform

From fa23fe840364ba54f081c4eccd3b52c1752e1744 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Tue, 31 Oct 2023 23:50:46 +0000
Subject: [PATCH 3/6] formatting

---
 .../workloads/fastmri/fastmri_pytorch/workload.py   |  4 +---
 .../imagenet_resnet/imagenet_pytorch/workload.py    |  4 +---
 .../librispeech_jax/spectrum_augmenter.py           |  4 ++--
 .../librispeech_pytorch/workload.py                 |  9 ++++-----
 algorithmic_efficiency/workloads/mnist/workload.py  |  4 +---
 .../workloads/wmt/wmt_pytorch/models.py             |  4 ++--
 baselines/shampoo/jax/distributed_shampoo.py        | 12 ++++++------
 submission_runner.py                                | 14 ++++++--------
 8 files changed, 23 insertions(+), 32 deletions(-)

diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
index daaea9e10..c3252feb8 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
@@ -247,9 +247,7 @@ def _eval_model_on_split(self,
     for _ in range(num_batches):
       batch = next(self._eval_iters[split])
       batch_metrics = self._eval_model(params, batch, model_rng)
-      total_metrics = {
-          k: v + batch_metrics[k] for k, v in total_metrics.items()
-      }
+      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index c0fcaaef3..cc9d2febc 100644
--- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -282,9 +282,7 @@ def _eval_model_on_split(self,
           update_batch_norm=False)
       weights = batch.get('weights')
       batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
-      total_metrics = {
-          k: v + batch_metrics[k] for k, v in total_metrics.items()
-      }
+      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
index 2a6f73d4d..c16740629 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
@@ -81,8 +81,8 @@ def _get_mask(self,
           jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
          [batch_size, 1])
       multiplicity_tensor = masks_per_frame * choose_range
-      multiplicity_weights = (multiplicity_weights <
-                              multiplicity_tensor).astype(jnp.int32)
+      multiplicity_weights = (multiplicity_weights
+                              < multiplicity_tensor).astype(jnp.int32)
       pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
     else:
       pre_mask = jnp.einsum('bmt->bt', pre_mask)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index c4f4a1247..d2774d3b9 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -227,8 +227,9 @@ def greedy_decode(
   idxs = torch.arange(
       fin_result.numel(), device=result.device).view(*fin_result.shape)
   mask = torch.arange(
-      fin_result.shape[1], device=result.device).view(
-          1, -1) < result.count_nonzero(dim=1).view(-1, 1)
+      fin_result.shape[1],
+      device=result.device).view(1, -1) < result.count_nonzero(dim=1).view(
+          -1, 1)
   fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
   padding = fin_result == 0
   return fin_result, padding
@@ -296,9 +297,7 @@ def _eval_model_on_split(self,
           'word_errors': word_errors,
           'num_words': num_words,
       }
-      total_metrics = {
-          k: v + batch_metrics[k] for k, v in total_metrics.items()
-      }
+      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py
index dcc195170..959228755 100644
--- a/algorithmic_efficiency/workloads/mnist/workload.py
+++ b/algorithmic_efficiency/workloads/mnist/workload.py
@@ -214,8 +214,6 @@ def _eval_model_on_split(self,
          batch,
          model_state,
          per_device_model_rngs)
-      total_metrics = {
-          k: v + batch_metrics[k] for k, v in total_metrics.items()
-      }
+      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
 
     return self._normalize_eval_metrics(num_examples, total_metrics)
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
index b787785a1..dc8ebea90 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
@@ -912,8 +912,8 @@ def forward(self,
       # not the remaining zero elements.
       if attn_mask is not None:
         raise ValueError('Attention mask has to be None for decode == True.')
-      attn_mask = (torch.arange(max_len, device=k.device) >=
-                   cache_index).reshape(1, max_len)
+      attn_mask = (torch.arange(max_len, device=k.device)
+                   >= cache_index).reshape(1, max_len)
 
       # Update sequence length to account for complete sequence.
       seq_len = k.size(1)
diff --git a/baselines/shampoo/jax/distributed_shampoo.py b/baselines/shampoo/jax/distributed_shampoo.py
index 725529cae..21f088c1b 100644
--- a/baselines/shampoo/jax/distributed_shampoo.py
+++ b/baselines/shampoo/jax/distributed_shampoo.py
@@ -595,8 +595,8 @@ def matrix_inverse_pth_root(
 
   if padding_start is not None:
     # Zero out padding in identity as well for convergence checks.
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
-        matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
+          < padding_start).astype(matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix
@@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh(
   alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
   identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
   if padding_start is not None:
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
-        matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
+          < padding_start).astype(matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix
@@ -1923,8 +1923,8 @@ def _internal_inverse_pth_root_all():
       errors = metrics.inverse_pth_root_errors
       errors = errors.reshape((-1, 1, 1))
       predicate = jnp.logical_or(
-          jnp.isnan(errors),
-          errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
+          jnp.isnan(errors), errors
+          >= inverse_failure_threshold).astype(new_preconditioners.dtype)
       # TODO(rohananil): Check for numerical instabilities.
       new_conditional_preconditioners = (
           predicate * global_stats.preconditioners +
diff --git a/submission_runner.py b/submission_runner.py
index 6d4cc98e2..fc826b407 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -149,11 +149,9 @@
                      None,
                      'Value of rng seed. If None, a random seed will'
                      'be generated from hardware.')
-flags.DEFINE_boolean(
-    'set_pytorch_max_split_size',
-    None,
-    'If true, set pytorch max_split_size_mb to 256'
-)
+flags.DEFINE_boolean('set_pytorch_max_split_size',
+                     None,
+                     'If true, set pytorch max_split_size_mb to 256')
 FLAGS = flags.FLAGS
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
 
@@ -352,8 +350,8 @@ def train_once(
     train_state['is_time_remaining'] = (
         train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
     # Check if submission is eligible for an untimed eval.
-    if ((train_step_end_time - train_state['last_eval_time']) >=
-        workload.eval_period_time_sec or train_state['training_complete']):
+    if ((train_step_end_time - train_state['last_eval_time'])
+        >= workload.eval_period_time_sec or train_state['training_complete']):
       with profiler.profile('Evaluation'):
         del batch
         _reset_cuda_mem()
@@ -606,7 +604,7 @@ def main(_):
   # Prevent OOM on librispeech conformer.
   if FLAGS.workload == 'librispeech_conformer':
     os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'
-  
+
   if FLAGS.set_pytorch_max_split_size is True:
     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
 

From 9b958c36f8630042ca198ae3451239b5832dd90a Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Tue, 31 Oct 2023 23:54:37 +0000
Subject: [PATCH 4/6] revert formatting

---
 README.md                                                  | 5 -----
 .../workloads/fastmri/fastmri_pytorch/workload.py          | 4 +++-
 .../imagenet_resnet/imagenet_pytorch/workload.py           | 4 +++-
 .../librispeech_jax/spectrum_augmenter.py                  | 4 ++--
 .../librispeech_conformer/librispeech_pytorch/workload.py  | 5 ++---
 algorithmic_efficiency/workloads/mnist/workload.py         | 4 +++-
 algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py | 4 ++--
 baselines/shampoo/jax/distributed_shampoo.py               | 4 ++--
 submission_runner.py                                       | 4 ++--
 9 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index 289a93dec..197ba8a61 100644
--- a/README.md
+++ b/README.md
@@ -126,13 +126,8 @@ To use the Docker container as an interactive virtual environment, you can run a
   -v $HOME/algorithmic-efficiency:/algorithmic-efficiency \
   --gpus all \
   --ipc=host \
-<<<<<<< HEAD
-  <docker_image_name> \
-  -keep_container_alive true
-=======
   <docker_image_name> \
   --keep_container_alive true
->>>>>>> ba5c6f6175a0ce12f23a7f035613d9d1edc0b74a
 ```
 Note: You may have to use double quotes around `algorithmic-efficiency` [path] in the mounting `-v` flag. If the above command fails try replacing the following line:
 ```bash
diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
index c3252feb8..daaea9e10 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
@@ -247,7 +247,9 @@ def _eval_model_on_split(self,
     for _ in range(num_batches):
       batch = next(self._eval_iters[split])
       batch_metrics = self._eval_model(params, batch, model_rng)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index cc9d2febc..c0fcaaef3 100644
--- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -282,7 +282,9 @@ def _eval_model_on_split(self,
          update_batch_norm=False)
       weights = batch.get('weights')
       batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
index c16740629..2a6f73d4d 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
@@ -81,8 +81,8 @@ def _get_mask(self,
           jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
          [batch_size, 1])
       multiplicity_tensor = masks_per_frame * choose_range
-      multiplicity_weights = (multiplicity_weights
-                              < multiplicity_tensor).astype(jnp.int32)
+      multiplicity_weights = (multiplicity_weights <
+                              multiplicity_tensor).astype(jnp.int32)
       pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
     else:
       pre_mask = jnp.einsum('bmt->bt', pre_mask)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index d2774d3b9..167332ed0 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -227,9 +227,8 @@ def greedy_decode(
   idxs = torch.arange(
       fin_result.numel(), device=result.device).view(*fin_result.shape)
   mask = torch.arange(
-      fin_result.shape[1],
-      device=result.device).view(1, -1) < result.count_nonzero(dim=1).view(
-          -1, 1)
+      fin_result.shape[1], device=result.device).view(
+          1, -1) < result.count_nonzero(dim=1).view(-1, 1)
   fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
   padding = fin_result == 0
   return fin_result, padding
diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py
index 959228755..dcc195170 100644
--- a/algorithmic_efficiency/workloads/mnist/workload.py
+++ b/algorithmic_efficiency/workloads/mnist/workload.py
@@ -214,6 +214,8 @@ def _eval_model_on_split(self,
          batch,
          model_state,
          per_device_model_rngs)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
 
     return self._normalize_eval_metrics(num_examples, total_metrics)
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
index dc8ebea90..b787785a1 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
@@ -912,8 +912,8 @@ def forward(self,
       # not the remaining zero elements.
       if attn_mask is not None:
         raise ValueError('Attention mask has to be None for decode == True.')
-      attn_mask = (torch.arange(max_len, device=k.device)
-                   >= cache_index).reshape(1, max_len)
+      attn_mask = (torch.arange(max_len, device=k.device) >=
+                   cache_index).reshape(1, max_len)
 
       # Update sequence length to account for complete sequence.
       seq_len = k.size(1)
diff --git a/baselines/shampoo/jax/distributed_shampoo.py b/baselines/shampoo/jax/distributed_shampoo.py
index 21f088c1b..225454b2c 100644
--- a/baselines/shampoo/jax/distributed_shampoo.py
+++ b/baselines/shampoo/jax/distributed_shampoo.py
@@ -595,8 +595,8 @@ def matrix_inverse_pth_root(
 
   if padding_start is not None:
     # Zero out padding in identity as well for convergence checks.
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
-          < padding_start).astype(matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
+        matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix
diff --git a/submission_runner.py b/submission_runner.py
index fc826b407..a40e2090b 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -350,8 +350,8 @@ def train_once(
     train_state['is_time_remaining'] = (
         train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
     # Check if submission is eligible for an untimed eval.
-    if ((train_step_end_time - train_state['last_eval_time'])
-        >= workload.eval_period_time_sec or train_state['training_complete']):
+    if ((train_step_end_time - train_state['last_eval_time']) >=
+        workload.eval_period_time_sec or train_state['training_complete']):
       with profiler.profile('Evaluation'):
         del batch
         _reset_cuda_mem()

From ec876fa045079fbbfad924acf528b3eead257248 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 2 Nov 2023 17:41:35 +0000
Subject: [PATCH 5/6] formatting

---
 .../librispeech_conformer/librispeech_pytorch/workload.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index 167332ed0..c4f4a1247 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -296,7 +296,9 @@ def _eval_model_on_split(self,
           'word_errors': word_errors,
           'num_words': num_words,
       }
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)

From 691e2c81ab2821531a6a90b89cb88703a363518f Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 2 Nov 2023 18:20:03 +0000
Subject: [PATCH 6/6] nits

---
 baselines/shampoo/jax/distributed_shampoo.py | 8 ++++----
 submission_runner.py                         | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/baselines/shampoo/jax/distributed_shampoo.py b/baselines/shampoo/jax/distributed_shampoo.py
index 225454b2c..725529cae 100644
--- a/baselines/shampoo/jax/distributed_shampoo.py
+++ b/baselines/shampoo/jax/distributed_shampoo.py
@@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh(
   alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
   identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
   if padding_start is not None:
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
-          < padding_start).astype(matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
+        matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix
@@ -1923,8 +1923,8 @@ def _internal_inverse_pth_root_all():
       errors = metrics.inverse_pth_root_errors
       errors = errors.reshape((-1, 1, 1))
       predicate = jnp.logical_or(
-          jnp.isnan(errors), errors
-          >= inverse_failure_threshold).astype(new_preconditioners.dtype)
+          jnp.isnan(errors),
+          errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
       # TODO(rohananil): Check for numerical instabilities.
       new_conditional_preconditioners = (
           predicate * global_stats.preconditioners +
diff --git a/submission_runner.py b/submission_runner.py
index a40e2090b..d92732145 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -150,7 +150,7 @@
                      'Value of rng seed. If None, a random seed will'
                      'be generated from hardware.')
 flags.DEFINE_boolean('set_pytorch_max_split_size',
-                     None,
+                     False,
                      'If true, set pytorch max_split_size_mb to 256')
 FLAGS = flags.FLAGS
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
@@ -605,7 +605,7 @@ def main(_):
   if FLAGS.workload == 'librispeech_conformer':
     os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'
 
-  if FLAGS.set_pytorch_max_split_size is True:
+  if FLAGS.set_pytorch_max_split_size:
     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
 
   # Extend path according to framework.
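
For readers following the series, the net effect of the new `set_pytorch_max_split_size` flag is simply to export `PYTORCH_CUDA_ALLOC_CONF` before the PyTorch caching allocator hands out any GPU memory. The minimal sketch below illustrates that behavior outside of `submission_runner.py`; it is not part of the patches, and the `run_conformer_like_step` helper and its tensor shapes are invented purely for illustration.

import os

# The caching allocator reads PYTORCH_CUDA_ALLOC_CONF when CUDA memory is
# first allocated, so the variable must be set before any tensor lands on
# the GPU.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'

import torch


def run_conformer_like_step(batch_size: int = 8) -> torch.Tensor:
  """Hypothetical stand-in for one step of a memory-hungry workload."""
  if not torch.cuda.is_available():
    return torch.zeros(1)
  # Large activations like these are where fragmentation-related OOMs tend to
  # appear; capping the split size trades some speed for fewer failures.
  features = torch.randn(batch_size, 1600, 512, device='cuda')
  return features.mean()


if __name__ == '__main__':
  print(run_conformer_like_step())

The allocator only consults the variable when it first initializes, and changing it afterwards has no effect, which is presumably why the patch sets it near the top of `main()` before the workload is constructed.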