nits
priyakasimbeg committed Nov 2, 2023
1 parent ec876fa commit 691e2c8
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions baselines/shampoo/jax/distributed_shampoo.py
@@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh(
   alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
   identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
   if padding_start is not None:
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
-          < padding_start).astype(matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
+        matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix
@@ -1923,8 +1923,8 @@ def _internal_inverse_pth_root_all():
     errors = metrics.inverse_pth_root_errors
     errors = errors.reshape((-1, 1, 1))
     predicate = jnp.logical_or(
-        jnp.isnan(errors), errors
-        >= inverse_failure_threshold).astype(new_preconditioners.dtype)
+        jnp.isnan(errors),
+        errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
     # TODO(rohananil): Check for numerical instabilities.
     new_conditional_preconditioners = (
         predicate * global_stats.preconditioners +
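Both hunks in this file only re-wrap long lines; behavior is unchanged. For readers skimming the second hunk, the predicate is the failure mask used to fall back to the previous preconditioner. Below is a minimal standalone sketch of that pattern, using toy shapes, an illustrative threshold, and a hypothetical continuation of the truncated last line above; it is not the repository's code.

import jax.numpy as jnp

inverse_failure_threshold = 0.1  # illustrative value, not the repo's setting
errors = jnp.array([0.01, jnp.nan, 0.5]).reshape((-1, 1, 1))
old_preconditioners = jnp.ones((3, 2, 2))  # stand-in for global_stats.preconditioners
new_preconditioners = 2.0 * jnp.ones((3, 2, 2))

# 1.0 where the inverse-pth-root computation failed (NaN or error above the
# threshold), 0.0 where it succeeded.
predicate = jnp.logical_or(
    jnp.isnan(errors),
    errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)

# Keep the old preconditioner where the computation failed, otherwise take the
# new one (assumed continuation of the truncated expression above).
conditional_preconditioners = (
    predicate * old_preconditioners + (1.0 - predicate) * new_preconditioners)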
4 changes: 2 additions & 2 deletions submission_runner.py
@@ -150,7 +150,7 @@
     'Value of rng seed. If None, a random seed will'
     'be generated from hardware.')
 flags.DEFINE_boolean('set_pytorch_max_split_size',
-                     None,
+                     False,
                      'If true, set pytorch max_split_size_mb to 256')
 FLAGS = flags.FLAGS
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
@@ -605,7 +605,7 @@ def main(_):
   if FLAGS.workload == 'librispeech_conformer':
     os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'
 
-  if FLAGS.set_pytorch_max_split_size is True:
+  if FLAGS.set_pytorch_max_split_size:
     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
 
   # Extend path according to framework.
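The two changes above give the flag a False default and replace the explicit identity check against True with an ordinary truthiness test. A minimal standalone sketch of the resulting behavior, assuming only absl and omitting the rest of submission_runner.py:

import os

from absl import app
from absl import flags

flags.DEFINE_boolean('set_pytorch_max_split_size',
                     False,
                     'If true, set pytorch max_split_size_mb to 256')
FLAGS = flags.FLAGS


def main(_):
  # With a False default, the allocator setting is applied only when the flag
  # is enabled explicitly, e.g. --set_pytorch_max_split_size=true.
  if FLAGS.set_pytorch_max_split_size:
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'


if __name__ == '__main__':
  app.run(main)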
