diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 61db1a8f8c9..0de57fb27c6 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -348,7 +348,7 @@ def physical_chip_count(self): def client_create_options(self): return { 'max_inflight_computations': - xu.getenv_as('XLA_TPU_MAX_INFLIGHT_COMPUTATIONS', int, 4), + xu.getenv_as('XLA_TPU_MAX_INFLIGHT_COMPUTATIONS', int, 32), 'ml_framework_name': 'PyTorch/XLA', 'ml_framework_version':