Unit tests for MiCS (microsoft#4792)
In response to the request in microsoft#2964 (comment), I added three more unit tests related to MiCS.
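
For context, a minimal sketch of how MiCS is enabled in these tests: a ZeRO stage-3 config carrying `mics_shard_size`, with the model built under `deepspeed.zero.MiCS_Init`. The config values mirror the new `TestMiCSGatheredParametersFree` test below; the `torch.nn.Linear` module is only illustrative and not part of this PR.

```python
import torch
import deepspeed

# ZeRO stage-3 config with MiCS sharding enabled via mics_shard_size,
# mirroring the new TestMiCSGatheredParametersFree unit test.
config_dict = {"train_batch_size": 1, "zero_optimization": {"stage": 3, "mics_shard_size": 1}}

# Parameters created under MiCS_Init are partitioned according to the MiCS config.
with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
    model = torch.nn.Linear(10, 10)  # illustrative module, not from this PR
```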

There are two known issues:
- Testing on Torch 2.1.0 triggers `_IllegalWork` in the coalesced all-gather. I made changes to ignore this condition; I don't currently know the root cause.
- The MiCS implementation does not work with offloading, so the failure in `TestZeroPartialOffloadConfigSweep` is expected.

---------

Co-authored-by: Logan Adams <[email protected]>
2 people authored and amaurya committed Feb 17, 2024
1 parent d5cf8f8 commit 7af1dd1
Showing 4 changed files with 39 additions and 2 deletions.
9 changes: 8 additions & 1 deletion deepspeed/runtime/zero/mics.py
@@ -41,7 +41,14 @@ def wait(self) -> None:
         """
         """
         # let the current stream to op
-        instrument_w_nvtx(self.allgather_handle.wait)()
+        try:
+            instrument_w_nvtx(self.allgather_handle.wait)()
+        except RuntimeError as e:
+            log_dist(
+                f"WARNING: Runtime Error while waiting the collective all-gather, possibly due to the _IllegalWork",
+                ranks=[0])
+            log_dist(f"Error message: {e}", ranks=[0])
+
         if self.complete:
             return
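
For reference, a minimal, self-contained sketch (not DeepSpeed code) of the tolerate-and-log pattern the hunk above adds: wait on an async collective handle, but catch and report a RuntimeError instead of failing. It assumes a torch.distributed process group is already initialized (e.g. via torchrun).

```python
import torch
import torch.distributed as dist


def tolerant_wait(handle) -> None:
    """Wait on an async collective handle; log and continue if wait() raises."""
    try:
        handle.wait()
    except RuntimeError as e:
        if dist.get_rank() == 0:
            print(f"WARNING: wait() on the all-gather handle raised: {e}")


# Example usage (assumes dist.init_process_group(...) has already run):
# tensor = torch.ones(4)
# gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
# handle = dist.all_gather(gathered, tensor, async_op=True)
# tolerant_wait(handle)
```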
8 changes: 7 additions & 1 deletion tests/unit/runtime/zero/test_zero.py
@@ -85,18 +85,24 @@ def test(self, zero_stage):
 
 
 # testing the fix https://github.com/microsoft/DeepSpeed/pull/1227
+@pytest.mark.parametrize("mics_enabled", [True, False])
 class TestZero3RepeatForwardLoop(DistributedTest):
     world_size = 1
 
-    def test(self, zero_stage=3):
+    def test(self, mics_enabled, zero_stage=3):
         # force all params to be partitioned by forcing threshold=0
+        mics_shard_size = -1
+        if mics_enabled:
+            mics_shard_size = self.world_size
+
         config_dict = {
             "train_micro_batch_size_per_gpu": 2,
             "gradient_accumulation_steps": 2,
             "steps_per_print": 1,
             "zero_optimization": {
                 "stage": zero_stage,
                 "stage3_param_persistence_threshold": 0,
                 "mics_shard_size": mics_shard_size,
             },
             "optimizer": {
                 "type": "Adam",
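
To spell out what the parametrization above does, the two cases resolve to these `zero_optimization` sections (values taken from the diff; `world_size` is 1 in this test, and `-1` is the value used when MiCS is disabled):

```python
# mics_enabled = False -> mics_shard_size stays -1, i.e. plain ZeRO-3 without MiCS
zero_opt_without_mics = {"stage": 3, "stage3_param_persistence_threshold": 0, "mics_shard_size": -1}

# mics_enabled = True -> shard across world_size ranks (1 in this test)
zero_opt_with_mics = {"stage": 3, "stage3_param_persistence_threshold": 0, "mics_shard_size": 1}
```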
23 changes: 23 additions & 0 deletions tests/unit/runtime/zero/test_zero_context.py
@@ -81,6 +81,29 @@ def __init__(self, hidden_dim):
         assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized"
 
 
+class TestMiCSGatheredParametersFree(DistributedTest):
+    world_size = 1
+
+    def test(self):
+        config_dict = {"train_batch_size": 1, "zero_optimization": {"stage": 3, "mics_shard_size": 1}}
+        hidden_dim = 10
+
+        class MyModel(torch.nn.Module):
+
+            def __init__(self, hidden_dim):
+                super(MyModel, self).__init__()
+                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
+
+        with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
+            model = MyModel(hidden_dim)
+
+        with deepspeed.zero.GatheredParameters(list(model.parameters())):
+            assert model.l1.weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor"
+
+        # on exit from `GatheredParameters` the gathered params should be freed and not leak memory
+        assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized"
+
+
 class TestSerialContext(DistributedTest):
     world_size = 1
     init_distributed = False
1 change: 1 addition & 0 deletions tests/unit/runtime/zero/test_zero_offloadpp.py
@@ -39,6 +39,7 @@ class TestZeroPartialOffloadConfigSweep(DistributedTest):
     world_size = 4
 
     def test(self, h_dim: int, n_layers: int) -> None:
+
         config_dict = {
             "train_batch_size": 256,
             "steps_per_print": 1,
