Add xla autocast support, update autocast APIs in checkpointing (#8523)
savitha-aws authored Jan 2, 2025
1 parent 40efdb7 commit 31919d5
Showing 4 changed files with 38 additions and 12 deletions.
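
For orientation, a minimal, hypothetical usage sketch of what this change enables: running a gradient-checkpointed module under torch.autocast("xla"), with the checkpoint machinery now able to restore the XLA autocast state when it recomputes the forward pass during backward. This is not part of the diff; the module, shapes, and xm helpers are illustrative only.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.utils.checkpoint as checkpoint

device = xm.xla_device()
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).to(device)
# Reentrant checkpointing wants at least one input that requires grad.
x = torch.randn(8, 1024, device=device, requires_grad=True)

with torch.autocast("xla"):  # XLA autocast, typically bf16
  y = checkpoint.checkpoint(model, x)

y.sum().backward()  # the recomputed forward re-enters the saved autocast state
xm.mark_step()  # materialize the lazily traced XLA graph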
test/neuron/run_tests.sh: 3 changes (2 additions, 1 deletion)
@@ -184,6 +184,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/test_scan.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/test_grad_checkpoint.py"
+  run_test "$CDIR/test_grad_checkpoint.py" "$@" --test_autocast
   #run_test "$CDIR/eager/test_eager.py"
   run_test "$CDIR/eager/test_eager_with_xla_compile.py"
   run_test "$CDIR/eager/test_eager_with_torch_compile.py"
@@ -326,4 +327,4 @@ if [ "$LOGFILE" != "" ]; then
   run_tests 2>&1 | tee $LOGFILE
 else
   run_tests
-fi
+fi
test/test_grad_checkpoint.py: 24 changes (19 additions, 5 deletions)
@@ -3,6 +3,11 @@
 import torch_xla.debug.metrics as met
 import torch_xla
 import torch_xla.utils.checkpoint as checkpoint
+import argparse
+
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument('--test_autocast', action='store_true')
+FLAGS, leftovers = parser.parse_known_args()


 def run():
@@ -21,11 +26,20 @@ def run():
     dummy_data = torch.zeros(64, 1024, 14, 14, device=device)
     optimizer.zero_grad()
     x = dummy_data
-    for n_l, layer in enumerate(model):
-      if n_l > 0:
-        x = checkpoint.checkpoint(layer, x)
-      else:
-        x = layer(x)
+    if FLAGS.test_autocast:
+      with torch.autocast("xla"):
+        for n_l, layer in enumerate(model):
+          if n_l > 0:
+            x = checkpoint.checkpoint(layer, x)
+          else:
+            x = layer(x)
+    else:
+      for n_l, layer in enumerate(model):
+        if n_l > 0:
+          x = checkpoint.checkpoint(layer, x)
+        else:
+          x = layer(x)
+
     dummy_loss = x.sum()
     dummy_loss.backward()
     optimizer.step()
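
A side note on the argparse wiring added at the top of this test (an illustration, not taken from the diff): parse_known_args with add_help=False lets the script accept its own --test_autocast flag while tolerating any extra arguments a test harness forwards via "$@", instead of erroring out on unknown flags. The --some_harness_flag below is a hypothetical example argument.

import argparse

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--test_autocast', action='store_true')
# Unknown harness flags are returned in `leftovers` rather than raising.
FLAGS, leftovers = parser.parse_known_args(
    ['--test_autocast', '--some_harness_flag', '1'])
print(FLAGS.test_autocast)  # True
print(leftovers)  # ['--some_harness_flag', '1']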
test/tpu/run_tests.sh: 1 change (1 addition, 0 deletions)
@@ -22,6 +22,7 @@ XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_s
 python3 "$TEST_CDIR/test_autocast.py"
 python3 "$TEST_CDIR/test_fp8.py"
 python3 "$TEST_CDIR/test_grad_checkpoint.py"
+python3 "$TEST_CDIR/test_grad_checkpoint.py" "$@" --test_autocast
 python3 "$TEST_CDIR/dynamo/test_dynamo.py"
 python3 "$TEST_CDIR/dynamo/test_dynamo_dynamic_shape.py"
 python3 "$TEST_CDIR/spmd/test_spmd_debugging.py"
torch_xla/utils/checkpoint.py: 22 changes (16 additions, 6 deletions)
@@ -87,15 +87,24 @@ def forward(ctx, run_function, preserve_rng_state, *args):
     check_backward_validity(args)
     ctx.run_function = run_function
     ctx.preserve_rng_state = preserve_rng_state
+
     # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
     ctx.gpu_autocast_kwargs = {
-        "enabled": torch.is_autocast_enabled(),
-        "dtype": torch.get_autocast_gpu_dtype(),
+        "device_type": "cuda",
+        "enabled": torch.is_autocast_enabled("cuda"),
+        "dtype": torch.get_autocast_dtype("cuda"),
         "cache_enabled": torch.is_autocast_cache_enabled()
     }
     ctx.cpu_autocast_kwargs = {
-        "enabled": torch.is_autocast_cpu_enabled(),
-        "dtype": torch.get_autocast_cpu_dtype(),
+        "device_type": "cpu",
+        "enabled": torch.is_autocast_enabled("cpu"),
+        "dtype": torch.get_autocast_dtype("cpu"),
         "cache_enabled": torch.is_autocast_cache_enabled()
     }
+    ctx.xla_autocast_kwargs = {
+        "device_type": "xla",
+        "enabled": torch.is_autocast_enabled("xla"),
+        "dtype": torch.get_autocast_dtype("xla"),
+        "cache_enabled": torch.is_autocast_cache_enabled()
+    }
     if preserve_rng_state:
@@ -180,8 +189,9 @@ def backward(ctx, *args):
           set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
       detached_inputs = detach_variable(tuple(inputs))
       with torch.enable_grad(), \
-          torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
-          torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+          torch.autocast(**ctx.gpu_autocast_kwargs), \
+          torch.autocast(**ctx.cpu_autocast_kwargs), \
+          torch.autocast(**ctx.xla_autocast_kwargs):
         outputs = ctx.run_function(*detached_inputs)

     if isinstance(outputs, torch.Tensor):
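
To summarize the API migration in this file (a sketch, not part of the diff): forward() now snapshots autocast state with the device-parameterized torch.is_autocast_enabled(device_type) and torch.get_autocast_dtype(device_type), in place of the older per-device helpers such as torch.is_autocast_cpu_enabled() and torch.get_autocast_gpu_dtype(), and backward() re-enters those states with the generic torch.autocast(**kwargs) rather than torch.cuda.amp.autocast / torch.cpu.amp.autocast, now including an "xla" entry. The snippet below assumes a recent PyTorch (2.4+) and, for the "xla" entry, an environment with torch_xla installed; names like autocast_snapshot and recompute are illustrative.

import torch

def autocast_snapshot(device_type):
  # Device-parameterized introspection, mirroring the kwargs saved in forward().
  return {
      "device_type": device_type,
      "enabled": torch.is_autocast_enabled(device_type),
      "dtype": torch.get_autocast_dtype(device_type),
      "cache_enabled": torch.is_autocast_cache_enabled(),
  }

saved = {d: autocast_snapshot(d) for d in ("cuda", "cpu", "xla")}

def recompute(fn, *args):
  # Re-enter the saved states, as backward() does before rerunning the
  # checkpointed function; an entry with enabled=False leaves autocast off.
  with torch.autocast(**saved["cuda"]), \
      torch.autocast(**saved["cpu"]), \
      torch.autocast(**saved["xla"]):
    return fn(*args)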
