Skip to content

Commit

Permalink
Get around lazy init issue in test_ds_config_dict.py
Browse files Browse the repository at this point in the history
  • Loading branch information
delock committed Mar 7, 2024
1 parent 2658d41 commit cd8672d
Showing 1 changed file with 36 additions and 4 deletions.
40 changes: 36 additions & 4 deletions tests/unit/runtime/test_ds_config_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,6 @@ def base_config():
}
},
}
if get_accelerator().is_fp16_supported():
config_dict["fp16"] = {"enabled": True}
elif get_accelerator().is_bf16_supported():
config_dict["bf16"] = {"enabled": True}
return config_dict


Expand Down Expand Up @@ -164,11 +160,19 @@ class TestConfigLoad(DistributedTest):
world_size = 1

def test_dict(self, base_config):
    """A DeepSpeed config passed directly as an in-memory dict should initialize cleanly."""
    accel = get_accelerator()
    # Pick a reduced-precision mode only if the accelerator supports one;
    # fp16 takes precedence over bf16 when both are available.
    if accel.is_fp16_supported():
        base_config["fp16"] = {"enabled": True}
    elif accel.is_bf16_supported():
        base_config["bf16"] = {"enabled": True}
    net = SimpleModel(10)
    net, _, _, _ = deepspeed.initialize(config=base_config, model=net, model_parameters=net.parameters())

def test_json(self, base_config, tmpdir):
if get_accelerator().is_fp16_supported():
base_config["fp16"] = {"enabled": True}
elif get_accelerator().is_bf16_supported():
base_config["bf16"] = {"enabled": True}
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, 'w') as fp:
json.dump(base_config, fp)
Expand All @@ -177,6 +181,10 @@ def test_json(self, base_config, tmpdir):
model, _, _, _ = deepspeed.initialize(config=config_path, model=model, model_parameters=model.parameters())

def test_hjson(self, base_config, tmpdir):
if get_accelerator().is_fp16_supported():
base_config["fp16"] = {"enabled": True}
elif get_accelerator().is_bf16_supported():
base_config["bf16"] = {"enabled": True}
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, 'w') as fp:
hjson.dump(base_config, fp)
Expand All @@ -189,6 +197,10 @@ class TestDeprecatedDeepScaleConfig(DistributedTest):
world_size = 1

def test(self, base_config, tmpdir):
if get_accelerator().is_fp16_supported():
base_config["fp16"] = {"enabled": True}
elif get_accelerator().is_bf16_supported():
base_config["bf16"] = {"enabled": True}
config_path = create_config_from_dict(tmpdir, base_config)
parser = argparse.ArgumentParser()
args = parser.parse_args(args='')
Expand All @@ -210,6 +222,10 @@ class TestDistInit(DistributedTest):
world_size = 1

def test(self, base_config):
if get_accelerator().is_fp16_supported():
base_config["fp16"] = {"enabled": True}
elif get_accelerator().is_bf16_supported():
base_config["bf16"] = {"enabled": True}
hidden_dim = 10

model = SimpleModel(hidden_dim)
Expand All @@ -228,6 +244,10 @@ class TestInitNoOptimizer(DistributedTest):
world_size = 1

def test(self, base_config):
if get_accelerator().is_fp16_supported():
base_config["fp16"] = {"enabled": True}
elif get_accelerator().is_bf16_supported():
base_config["bf16"] = {"enabled": True}
if get_accelerator().device_name() == "cpu":
pytest.skip("This test timeout with CPU accelerator")
del base_config["optimizer"]
Expand All @@ -249,13 +269,21 @@ class TestArgs(DistributedTest):
world_size = 1

def test_none_args(self, base_config):
    """deepspeed.initialize must accept args=None; run a few forward passes to confirm the engine works."""
    accel = get_accelerator()
    # Enable fp16 when supported, otherwise fall back to bf16 if available.
    if accel.is_fp16_supported():
        base_config["fp16"] = {"enabled": True}
    elif accel.is_bf16_supported():
        base_config["bf16"] = {"enabled": True}
    engine, _, _, _ = deepspeed.initialize(args=None, model=SimpleModel(hidden_dim=10), config=base_config)
    loader = random_dataloader(model=engine, total_samples=5, hidden_dim=10, device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])

def test_no_args(self, base_config):
if get_accelerator().is_fp16_supported():
base_config["fp16"] = {"enabled": True}
elif get_accelerator().is_bf16_supported():
base_config["bf16"] = {"enabled": True}
model = SimpleModel(hidden_dim=10)
model, _, _, _ = deepspeed.initialize(model=model, config=base_config)
data_loader = random_dataloader(model=model, total_samples=5, hidden_dim=10, device=model.device)
Expand All @@ -267,6 +295,10 @@ class TestNoModel(DistributedTest):
world_size = 1

def test(self, base_config):
    """deepspeed.initialize must raise AssertionError when model=None is passed."""
    accel = get_accelerator()
    # Select a reduced-precision mode matching what the accelerator supports.
    if accel.is_fp16_supported():
        base_config["fp16"] = {"enabled": True}
    elif accel.is_bf16_supported():
        base_config["bf16"] = {"enabled": True}
    # A model is constructed but deliberately NOT handed to initialize below.
    model = SimpleModel(hidden_dim=10)
    with pytest.raises(AssertionError):
        deepspeed.initialize(model=None, config=base_config)
Expand Down

0 comments on commit cd8672d

Please sign in to comment.