Commit

Merge branch 'master' into mrwyattii/pydantic-2-support
loadams authored Jul 25, 2024
2 parents 79c0835 + 45b3635 commit 295a806
Showing 8 changed files with 35 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cpu-torch-latest.yml
@@ -19,7 +19,7 @@ concurrency:

 jobs:
   unit-tests:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04

     steps:
       - uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion .github/workflows/formatting.yml
@@ -18,7 +18,7 @@ jobs:

   # formatting and basic install on cpu-only machine
   unit-tests:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04

     steps:
       - uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion .github/workflows/nv-mii.yml
@@ -46,7 +46,7 @@ jobs:
 git clone https://github.com/huggingface/transformers
 cd transformers
 # if needed switch to the last known good SHA until transformers@master is fixed
-git checkout bdf36dc
+# git checkout bdf36dc
 git rev-parse --short HEAD
 pip install .
2 changes: 1 addition & 1 deletion .github/workflows/nv-pre-compile-ops.yml
@@ -21,7 +21,7 @@ concurrency:

 jobs:
   unit-tests:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04
     container:
       image: deepspeed/gh-builder:ubuntu1804-py38-torch1131-cu116

2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -7,7 +7,7 @@ on:

 jobs:
   deploy:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04
     environment: release-env

     steps:
6 changes: 5 additions & 1 deletion accelerator/xpu_accelerator.py
@@ -26,7 +26,11 @@ def is_synchronized_device(self):
         return False

     def use_host_timers(self):
-        return self.is_synchronized_device()
+        # WA XPU event will be consolidated in 2.5
+        if ipex.__version__ < '2.5':
+            return True
+        else:
+            return self.is_synchronized_device()

     def resolves_data_dependency(self):
         return self.is_synchronized_device()
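Note on the xpu_accelerator.py change: use_host_timers() now returns True on IPEX releases before 2.5, as a workaround until XPU event handling is consolidated. The committed code compares version strings lexicographically (ipex.__version__ < '2.5'); the sketch below expresses the same gate using packaging.version for a numeric comparison. Illustration only, assuming the conventional intel_extension_for_pytorch import; it is not the committed implementation.

    # Sketch of the version gate above, using packaging.version instead of a
    # lexicographic string comparison (assumption: IPEX is importable as below).
    from packaging import version

    import intel_extension_for_pytorch as ipex


    def use_host_timers(accelerator) -> bool:
        # Workaround ("WA" in the diff comment): XPU events are expected to be
        # consolidated in IPEX 2.5, so prefer host-side timers before that.
        if version.parse(ipex.__version__) < version.parse("2.5"):
            return True
        return accelerator.is_synchronized_device()
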
2 changes: 2 additions & 0 deletions deepspeed/runtime/pipe/engine.py
@@ -213,6 +213,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
             self.module.activation_checkpoint_func = ds_checkpointing.non_reentrant_checkpoint
             if self.grid.get_global_rank() == 0:
                 logger.info(f'CONFIG: activation_checkpoint_func=non_reentrant_checkpoint')
+        if self.module.activation_checkpoint_interval > 0:
+            self.module._precompute_checkpointable_values()

         self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline
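The engine-side hook added here is small: once the activation checkpoint function has been resolved, the engine asks the pipeline module to precompute which layer ranges are checkpointable, so later forward passes only read cached flags (see module.py below). A condensed sketch of that ordering follows; the attribute and method names come from the diff, and everything else about PipelineEngine is elided or assumed.

    # Condensed sketch of the engine-side call added in this commit.
    class EngineInitSketch:
        def __init__(self, module):
            self.module = module
            # New: with interval-based activation checkpointing, cache the
            # per-range checkpointability once up front instead of per forward().
            if self.module.activation_checkpoint_interval > 0:
                self.module._precompute_checkpointable_values()
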
29 changes: 23 additions & 6 deletions deepspeed/runtime/pipe/module.py
@@ -196,17 +196,32 @@ def __init__(self,
         #newseed = get_accelerator().initial_seed() + self._grid.get_stage_id()
         #ds_utils.set_random_seed(newseed)

+        self.activation_checkpoint_interval = activation_checkpoint_interval
+
+        self.activation_checkpoint_func = activation_checkpoint_func
+
+        #storage for precomputed checkpointeble results
+        self.is_checkpointable_results = []
+        self.is_checkpointable_results_interval = None
+
+        # if configuration use_reentrant = False, self.activation_checkpoint_func will be set to ``checkpointing.non_reentrant_checkpoint``
+
         #with torch.random.fork_rng(devices=[get_accelerator().current_device_name()]):
         self._build()
         self.to(get_accelerator().device_name(self.local_rank))

         self.tied_comms = self._index_tied_modules()
         self._synchronize_tied_weights()

-        self.activation_checkpoint_interval = activation_checkpoint_interval
-
-        self.activation_checkpoint_func = activation_checkpoint_func
-        # if configuration use_reentrant = False, self.activation_checkpoint_func will be set to ``checkpointing.non_reentrant_checkpoint``
+    def _precompute_checkpointable_values(self):
+        if self.activation_checkpoint_interval > 0 and self.is_checkpointable_results_interval != self.activation_checkpoint_interval:
+            num_layers = len(self.forward_funcs)
+            self.interval_was_zero = False
+            for start_idx in range(0, num_layers, self.activation_checkpoint_interval):
+                end_idx = min(start_idx + self.activation_checkpoint_interval, num_layers)
+                funcs = self.forward_funcs[start_idx:end_idx]
+                self.is_checkpointable_results.append(self._is_checkpointable(funcs))
+            self.is_checkpointable_results_interval = self.activation_checkpoint_interval

     def _build(self):
         specs = self._layer_specs
@@ -352,7 +367,9 @@ def exec_func(*inputs):
         else:
             num_layers = len(self.forward_funcs)
             x = forward_input
-            for start_idx in range(0, num_layers, self.activation_checkpoint_interval):
+            for start_idx, is_checkpointable_result in \
+                    zip(range(0, num_layers, self.activation_checkpoint_interval), self.is_checkpointable_results):
+
                 end_idx = min(start_idx + self.activation_checkpoint_interval, num_layers)

                 funcs = self.forward_funcs[start_idx:end_idx]
@@ -361,7 +378,7 @@ def exec_func(*inputs):
                 if not isinstance(x, tuple):
                     x = (x, )

-                if self._is_checkpointable(funcs):
+                if is_checkpointable_result:
                     x = self.activation_checkpoint_func(exec_range_func(start_idx, end_idx), *x)
                 else:
                     x = exec_range_func(start_idx, end_idx)(*x)
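
Taken together, the module.py changes move the per-range _is_checkpointable() evaluation out of forward() and into a one-time precomputation keyed by the checkpoint interval; forward() then zips the cached flags with the layer ranges. Below is a standalone sketch of that pattern with the stage reduced to a list of callables and checkpointability stubbed out; it illustrates the caching idea only and is not DeepSpeed's PipelineModule.

    # Standalone sketch of the precompute-and-reuse pattern from the diff above.
    # `forward_funcs` stands in for the stage's layer callables; `_is_checkpointable`
    # is stubbed (the real check inspects the layers), so treat both as assumptions.
    from typing import Callable, List, Optional


    class StageSketch:
        def __init__(self, forward_funcs: List[Callable], interval: int):
            self.forward_funcs = forward_funcs
            self.activation_checkpoint_interval = interval
            self.is_checkpointable_results: List[bool] = []
            self.is_checkpointable_results_interval: Optional[int] = None

        def _is_checkpointable(self, funcs: List[Callable]) -> bool:
            # Stub: DeepSpeed inspects the layers (e.g. for parameters) here.
            return len(funcs) > 0

        def _precompute_checkpointable_values(self) -> None:
            # Evaluate checkpointability once per layer range and cache it,
            # skipping the work if it was already done for this interval.
            if (self.activation_checkpoint_interval > 0
                    and self.is_checkpointable_results_interval != self.activation_checkpoint_interval):
                num_layers = len(self.forward_funcs)
                for start in range(0, num_layers, self.activation_checkpoint_interval):
                    end = min(start + self.activation_checkpoint_interval, num_layers)
                    self.is_checkpointable_results.append(self._is_checkpointable(self.forward_funcs[start:end]))
                self.is_checkpointable_results_interval = self.activation_checkpoint_interval

        def forward(self, x):
            # forward() reads the cached flags instead of re-evaluating them on
            # every pass, mirroring the zip(...) introduced in the diff.
            num_layers = len(self.forward_funcs)
            for start, checkpointable in zip(range(0, num_layers, self.activation_checkpoint_interval),
                                             self.is_checkpointable_results):
                end = min(start + self.activation_checkpoint_interval, num_layers)

                def run_range(inp, lo=start, hi=end):
                    for fn in self.forward_funcs[lo:hi]:
                        inp = fn(inp)
                    return inp

                if checkpointable:
                    # The real module wraps this call in activation_checkpoint_func.
                    x = run_range(x)
                else:
                    x = run_range(x)
            return x


    # Example: five layers with interval 2 -> three cached flags, computed once.
    stage = StageSketch([lambda v: v + 1 for _ in range(5)], interval=2)
    stage._precompute_checkpointable_values()
    assert stage.is_checkpointable_results == [True, True, True]
    assert stage.forward(0) == 5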
