From 31919d54206687debe69978ad8250ab81bcaef3e Mon Sep 17 00:00:00 2001
From: savitha-aws
Date: Thu, 2 Jan 2025 13:11:46 -0800
Subject: [PATCH] Add xla autocast support, update autocast APIs in
 checkpointing (#8523)

---
 test/neuron/run_tests.sh      |  3 ++-
 test/test_grad_checkpoint.py  | 24 +++++++++++++++++++-----
 test/tpu/run_tests.sh         |  1 +
 torch_xla/utils/checkpoint.py | 22 ++++++++++++++++------
 4 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh
index 57da0ff799a..93af6393ce8 100755
--- a/test/neuron/run_tests.sh
+++ b/test/neuron/run_tests.sh
@@ -184,6 +184,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/test_scan.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/test_grad_checkpoint.py"
+  run_test "$CDIR/test_grad_checkpoint.py" "$@" --test_autocast
   #run_test "$CDIR/eager/test_eager.py"
   run_test "$CDIR/eager/test_eager_with_xla_compile.py"
   run_test "$CDIR/eager/test_eager_with_torch_compile.py"
@@ -326,4 +327,4 @@ if [ "$LOGFILE" != "" ]; then
   run_tests 2>&1 | tee $LOGFILE
 else
   run_tests
-fi
+fi
\ No newline at end of file
diff --git a/test/test_grad_checkpoint.py b/test/test_grad_checkpoint.py
index 9a5fd19aa99..63761a32a40 100644
--- a/test/test_grad_checkpoint.py
+++ b/test/test_grad_checkpoint.py
@@ -3,6 +3,11 @@
 import torch_xla.debug.metrics as met
 import torch_xla
 import torch_xla.utils.checkpoint as checkpoint
+import argparse
+
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument('--test_autocast', action='store_true')
+FLAGS, leftovers = parser.parse_known_args()
 
 
 def run():
@@ -21,11 +26,20 @@ def run():
   dummy_data = torch.zeros(64, 1024, 14, 14, device=device)
   optimizer.zero_grad()
   x = dummy_data
-  for n_l, layer in enumerate(model):
-    if n_l > 0:
-      x = checkpoint.checkpoint(layer, x)
-    else:
-      x = layer(x)
+  if FLAGS.test_autocast:
+    with torch.autocast("xla"):
+      for n_l, layer in enumerate(model):
+        if n_l > 0:
+          x = checkpoint.checkpoint(layer, x)
+        else:
+          x = layer(x)
+  else:
+    for n_l, layer in enumerate(model):
+      if n_l > 0:
+        x = checkpoint.checkpoint(layer, x)
+      else:
+        x = layer(x)
+
   dummy_loss = x.sum()
   dummy_loss.backward()
   optimizer.step()
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index c6e90e9ae0a..03c32924c9b 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -22,6 +22,7 @@ XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_s
 python3 "$TEST_CDIR/test_autocast.py"
 python3 "$TEST_CDIR/test_fp8.py"
 python3 "$TEST_CDIR/test_grad_checkpoint.py"
+python3 "$TEST_CDIR/test_grad_checkpoint.py" "$@" --test_autocast
 python3 "$TEST_CDIR/dynamo/test_dynamo.py"
 python3 "$TEST_CDIR/dynamo/test_dynamo_dynamic_shape.py"
 python3 "$TEST_CDIR/spmd/test_spmd_debugging.py"
diff --git a/torch_xla/utils/checkpoint.py b/torch_xla/utils/checkpoint.py
index 1bdf4bdd0a8..220dbe01188 100644
--- a/torch_xla/utils/checkpoint.py
+++ b/torch_xla/utils/checkpoint.py
@@ -87,15 +87,24 @@ def forward(ctx, run_function, preserve_rng_state, *args):
     check_backward_validity(args)
     ctx.run_function = run_function
     ctx.preserve_rng_state = preserve_rng_state
+    # Accommodates the (remote) possibility that autocast is enabled for cpu, gpu AND xla.
     ctx.gpu_autocast_kwargs = {
-        "enabled": torch.is_autocast_enabled(),
-        "dtype": torch.get_autocast_gpu_dtype(),
+        "device_type": "cuda",
+        "enabled": torch.is_autocast_enabled("cuda"),
+        "dtype": torch.get_autocast_dtype("cuda"),
         "cache_enabled": torch.is_autocast_cache_enabled()
     }
     ctx.cpu_autocast_kwargs = {
-        "enabled": torch.is_autocast_cpu_enabled(),
-        "dtype": torch.get_autocast_cpu_dtype(),
+        "device_type": "cpu",
+        "enabled": torch.is_autocast_enabled("cpu"),
+        "dtype": torch.get_autocast_dtype("cpu"),
+        "cache_enabled": torch.is_autocast_cache_enabled()
+    }
+    ctx.xla_autocast_kwargs = {
+        "device_type": "xla",
+        "enabled": torch.is_autocast_enabled("xla"),
+        "dtype": torch.get_autocast_dtype("xla"),
         "cache_enabled": torch.is_autocast_cache_enabled()
     }
     if preserve_rng_state:
@@ -180,8 +189,9 @@ def backward(ctx, *args):
           set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
     detached_inputs = detach_variable(tuple(inputs))
     with torch.enable_grad(), \
-        torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
-        torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+        torch.autocast(**ctx.gpu_autocast_kwargs), \
+        torch.autocast(**ctx.cpu_autocast_kwargs), \
+        torch.autocast(**ctx.xla_autocast_kwargs):
       outputs = ctx.run_function(*detached_inputs)
 
     if isinstance(outputs, torch.Tensor):
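
---

For illustration, a minimal standalone sketch of the device-generic autocast
APIs this patch switches to (torch.is_autocast_enabled(device_type),
torch.get_autocast_dtype(device_type), torch.autocast(device_type=...)),
assuming PyTorch >= 2.4. capture_autocast_state is a hypothetical helper, not
part of the patch; it mirrors what CheckpointFunction.forward saves into the
ctx.*_autocast_kwargs dicts above:

    import torch

    def capture_autocast_state(device_type: str) -> dict:
      # Snapshot the ambient autocast state for one device type.
      return {
          "device_type": device_type,
          "enabled": torch.is_autocast_enabled(device_type),
          "dtype": torch.get_autocast_dtype(device_type),
          "cache_enabled": torch.is_autocast_cache_enabled(),
      }

    # Capture inside an active autocast region...
    with torch.autocast("cpu", dtype=torch.bfloat16):
      saved = capture_autocast_state("cpu")

    # ...and re-enter the same settings later, as backward() does before
    # recomputing the forward pass, so recomputation sees identical dtypes.
    with torch.autocast(**saved):
      y = torch.mm(torch.randn(4, 4), torch.randn(4, 4))
      assert y.dtype == torch.bfloat16

Saving one kwargs dict per device type (cuda, cpu, xla) and unconditionally
re-entering all three contexts in backward() keeps the recompute path
branch-free: a context whose "enabled" flag is False simply reproduces the
disabled state that was in effect during the original forward.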