Skip to content

Commit

Permalink
tune max split size
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Sep 28, 2023
1 parent 0f20678 commit 6cf192a
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,14 @@
from algorithmic_efficiency.pytorch_utils import sync_ddp_time
from algorithmic_efficiency.workloads import workloads


# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.set_visible_devices([], 'GPU')

# disable only for deepspeech if it works fine for other workloads.
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'

# os.environ["CUDA_VISIBLE_DEVICES"]='0'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"

# TODO(znado): make a nicer registry of workloads that lookup in.
BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
# TODO(znado): make a nicer registry of workloads that lookup in.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'

# Workload_path will be appended by '_pytorch' or '_jax' automatically.
WORKLOADS = workloads.WORKLOADS
Expand Down Expand Up @@ -209,9 +205,8 @@ def train_once(
model_params, model_state = workload.init_model_fn(
model_init_rng, dropout_rate, aux_dropout_rate)
if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
compile_error_workloads = ['ogbg', 'criteo1tb']
eager_backend_workloads = [
'librispeech_conformer', 'librispeech_deepspeech'
compile_error_workloads = ['ogbg', 'criteo1tb', 'librispeech_conformer']
eager_backend_workloads = ['librispeech_deepspeech']
aot_eager_backend_workloads = []
if FLAGS.workload in compile_error_workloads:
Expand Down Expand Up @@ -422,7 +417,6 @@ def train_once(
_reset_cuda_mem()

train_state['last_step_end_time'] = get_time()

metrics = {'eval_results': eval_results, 'global_step': global_step}

if log_dir is not None:
Expand Down

0 comments on commit 6cf192a

Please sign in to comment.