-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Guanhua/partial offload rebase v2 (#590)
Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Jeff Rasley <[email protected]>
- Loading branch information
1 parent
a591992
commit fedffc5
Showing
8 changed files
with
355 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import os | ||
import json | ||
import argparse | ||
import torch | ||
import deepspeed | ||
from torch.utils.data.distributed import DistributedSampler | ||
import deepspeed.comm as dist | ||
|
||
|
||
class SimpleModel(torch.nn.Module): | ||
|
||
def __init__(self, hidden_dim, empty_grad=False): | ||
super(SimpleModel, self).__init__() | ||
self.linear = torch.nn.Linear(hidden_dim, hidden_dim) | ||
self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) | ||
self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim) | ||
self.linear4 = torch.nn.Linear(hidden_dim, hidden_dim) | ||
if empty_grad: | ||
self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)]) | ||
self.cross_entropy_loss = torch.nn.CrossEntropyLoss() | ||
|
||
def forward(self, x, y): | ||
hidden = x | ||
hidden = self.linear(hidden) | ||
hidden = self.linear2(hidden) | ||
hidden = self.linear3(hidden) | ||
hidden = self.linear4(hidden) | ||
return self.cross_entropy_loss(hidden, y) | ||
|
||
|
||
def create_config_from_dict(tmpdir, config_dict): | ||
config_path = os.path.join(tmpdir, 'temp_config.json') | ||
with open(config_path, 'w') as fd: | ||
json.dump(config_dict, fd) | ||
return config_path | ||
|
||
|
||
def get_data_loader(model, total_samples, hidden_dim, device): | ||
batch_size = model.train_micro_batch_size_per_gpu() | ||
train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half) | ||
train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) | ||
train_dataset = torch.utils.data.TensorDataset(train_data, train_label) | ||
sampler = DistributedSampler(train_dataset) | ||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler) | ||
return train_loader | ||
|
||
|
||
def get_args(tmpdir, config_dict): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--local_rank", type=int, default=0) | ||
parser.add_argument('--zero', type=int, default=0) | ||
args = parser.parse_args() #args='' | ||
|
||
config_dict["zero_optimization"]["stage"] = args.zero | ||
print('config_dict["zero_optimization"]', config_dict["zero_optimization"]) | ||
config_path = create_config_from_dict(tmpdir, config_dict) | ||
|
||
args.deepspeed_config = config_path | ||
return args | ||
|
||
|
||
def print0(msg): | ||
if dist.get_rank() == 0: | ||
print(msg, flush=True) | ||
|
||
|
||
rank = int(os.environ['RANK']) | ||
print('seed:', 2222 + rank) | ||
torch.random.manual_seed(2222 + rank) | ||
|
||
config_dict = { | ||
"train_batch_size": 256, | ||
"steps_per_print": 1, | ||
"optimizer": { | ||
"type": "Adam", | ||
"params": { | ||
"lr": 0.00015, | ||
} | ||
}, | ||
"fp16": { | ||
"enabled": True, | ||
"initial_scale_power": 15 | ||
}, | ||
"zero_optimization": { | ||
"stage": 0, | ||
"sub_group_size": 8, | ||
"reduce_bucket_size": 20, | ||
"offload_optimizer": { | ||
"device": "cpu", | ||
"pin_memory": True, | ||
"ratio": 0.3 | ||
} | ||
} | ||
} | ||
# "initial_scale_power": 15 | ||
args = get_args('/tmp/', config_dict) | ||
hidden_dim = 4 * 1024 | ||
|
||
model = SimpleModel(hidden_dim, empty_grad=False) | ||
|
||
model, _, _, _ = deepspeed.initialize(args=args, | ||
model=model, | ||
model_parameters=model.parameters(), | ||
dist_init_required=True) | ||
|
||
|
||
def print_params(tag, model): | ||
if dist.get_rank() == 0: | ||
for n, p in model.named_parameters(): | ||
print0("{} {}:{}".format(tag, n, p)) | ||
|
||
|
||
data_loader = get_data_loader(model=model, total_samples=1000, hidden_dim=hidden_dim, device=model.device) | ||
#print_params('pre-train', model) | ||
#while True: | ||
for n, batch in enumerate(data_loader): | ||
loss = model(batch[0], batch[1]) | ||
if dist.get_rank() == 0: | ||
print("LOSS:", loss.item()) | ||
model.backward(loss) | ||
model.step() | ||
#print_params('step={}'.format(n), model) | ||
if n == 2: break |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import torch | ||
import numpy as np | ||
import pytest | ||
|
||
from cpuinfo import get_cpu_info | ||
|
||
import deepspeed | ||
from deepspeed.accelerator import get_accelerator | ||
from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam | ||
from deepspeed.ops.op_builder import CPUAdamBuilder | ||
from unit.common import DistributedTest | ||
|
||
if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: | ||
pytest.skip("hybrid-adam is not compatible", allow_module_level=True) | ||
|
||
pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower() | ||
|
||
|
||
def check_equal(first, second, atol=1e-2, verbose=False): | ||
x = first.detach().numpy() | ||
y = second.detach().numpy() | ||
print("ATOL", atol) | ||
if verbose: | ||
print("x = {}".format(x.flatten())) | ||
print("y = {}".format(y.flatten())) | ||
print('-' * 80) | ||
np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) | ||
|
||
|
||
@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) | ||
@pytest.mark.parametrize('model_size', [8, 16]) | ||
class TestHybridAdam(DistributedTest): | ||
world_size = 1 | ||
reuse_dist_env = True | ||
requires_cuda_env = False | ||
if not get_accelerator().is_available(): | ||
init_distributed = False | ||
set_dist_env = False | ||
|
||
@pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") | ||
def test_hybrid_adam_equal(self, dtype, model_size): | ||
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): | ||
pytest.skip("cpu-adam with half precision not supported on AMD CPUs") | ||
|
||
ref_data = torch.randn(model_size).to(dtype) | ||
total_data = ref_data.clone().detach() | ||
|
||
ref_param = torch.nn.Parameter(ref_data) | ||
ref_optimizer = DeepSpeedCPUAdam([ref_param]) | ||
|
||
cpu_data, cuda_data = total_data.chunk(2) | ||
cpu_param = torch.nn.Parameter(cpu_data) | ||
cuda_param = torch.nn.Parameter(cuda_data.to(get_accelerator().device_name())) | ||
|
||
cpu_optimizer = DeepSpeedCPUAdam([cpu_param]) | ||
cuda_optimizer = FusedAdam([cuda_param]) | ||
|
||
ref_grad = torch.randn(model_size).to(dtype) | ||
cpu_grad, cuda_grad = ref_grad.clone().detach().chunk(2) | ||
|
||
ref_param.grad = ref_grad | ||
cpu_param.grad = cpu_grad | ||
cuda_param.grad = cuda_grad.to(get_accelerator().device_name()) | ||
|
||
ref_optimizer.step() | ||
cpu_optimizer.step() | ||
cuda_optimizer.step() | ||
|
||
cuda_param_copy = cuda_param.cpu() | ||
|
||
total_param = torch.cat((cpu_param, cuda_param_copy)) | ||
|
||
check_equal(ref_param, total_param) |
Oops, something went wrong.