diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 3a272159a9f6..8fd9359166c2 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -701,6 +701,9 @@ def zero_cpu_offload(self): return self._config.zero_config.offload_optimizer.device == OffloadDeviceEnum.cpu return False + def zero_partial_offload(self): + return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0) + def zero_sub_group_size(self): return self._config.zero_config.sub_group_size @@ -1580,6 +1583,7 @@ def _configure_zero_optimizer(self, optimizer): offload_optimizer_config=self.zero_offload_optimizer(), offload_param_config=self.zero_offload_param(), sub_group_size=self.zero_sub_group_size(), + offload_ratio=self.zero_partial_offload(), mpu=self.mpu, postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor(), diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 35d60b5b3290..f16dfd7ac4c0 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -6,7 +6,7 @@ import sys from typing import Optional from enum import Enum -from deepspeed.pydantic_v1 import Field, validator +from deepspeed.pydantic_v1 import Field, validator, root_validator from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel from deepspeed.utils import logger from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum @@ -300,3 +300,10 @@ def overlap_comm_valid(cls, field_value, values): assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'" field_value = values["stage"] == ZeroStageEnum.weights return field_value + + @root_validator + def offload_ratio_check(cls, values): + offload_config = getattr(values, "offload_optimizer", {}) + if offload_config and offload_config.ratio < 1.0: + assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3." + return values diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index 1bd79412d39f..b7adc13a0ea2 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -92,3 +92,6 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel): def set_pipeline(cls, field_value, values): values["pipeline"] = field_value or values.get("pipeline", False) return field_value + + ratio: float = Field(1.0, ge=0.0, le=1.0) + """ Percentage of offloaded optimizer states to CPU Adam. 
Only valid with ZeRO Stage 3.""" diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 47e453cea192..fd06e2685658 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -20,7 +20,7 @@ from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload -from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper @@ -104,6 +104,7 @@ def __init__( offload_optimizer_config=None, offload_param_config=None, sub_group_size=1000000000000, + offload_ratio=0.0, mpu=None, clip_grad=0.0, gradient_accumulation_dtype=torch.float32, @@ -159,6 +160,7 @@ def __init__( self.offload_param_pin_memory = False self.params_in_nvme_and_cpu = False self.max_params_in_cpu = 0 + self.partial_offload = offload_ratio #num of ranks in a ZeRO param partitioning group self.zero_hpz_partition_size = zero_hpz_partition_size @@ -191,6 +193,23 @@ def __init__( self.persistent_parameters = self.parameter_offload.persistent_parameters self._configure_offloading(offload_optimizer_config, offload_param_config) + # backup fused_adam optimizer init + if self.offload_optimizer and self.partial_offload != 1.0: + backup_gpu_tensor = torch.randn(1, device='cuda').to(self.dtype) + backup_gpu_param = torch.nn.Parameter(backup_gpu_tensor) + assert type(init_optimizer) == DeepSpeedCPUAdam, 'Hybrid Optimizer Only Supports DeepSpeedCPUAdam' + self.backup_optimizer = FusedAdam([backup_gpu_param], + lr=self.optimizer.param_groups[0]["lr"], + bias_correction=self.optimizer.param_groups[0]["bias_correction"], + betas=self.optimizer.param_groups[0]["betas"], + eps=self.optimizer.param_groups[0]["eps"], + weight_decay=self.optimizer.param_groups[0]["weight_decay"], + amsgrad=self.optimizer.param_groups[0]["amsgrad"]) + # Multiple param_groups configs for back-up optimizer + if len(self.optimizer.param_groups) > 1: + for i in range(1, len(self.optimizer.param_groups)): + self.backup_optimizer.add_param_group(self.optimizer.param_groups[i]) + self.module = module self.elastic_checkpoint = elastic_checkpoint @@ -780,6 +799,17 @@ def _create_fp32_partitions(self): nvme_fp32_dest_tensors = [] fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() + # Assign portion of subgroup to cpu, the other to gpu. 
+ if self.offload_optimizer: + self.subgroup_to_device = {} + sub_group_size = len(self.fp16_partitioned_groups_flat) + # print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n") + for i in range(sub_group_size): + if i < int(self.partial_offload * sub_group_size): + self.subgroup_to_device[i] = 'cpu' + else: + self.subgroup_to_device[i] = get_accelerator()._name + for i, tensor in enumerate(self.fp16_partitioned_groups_flat): num_elements = self.fp16_partitioned_groups_flat_numel[i] @@ -816,8 +846,12 @@ def _create_fp32_partitions(self): self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) else: - self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( - self.device).clone().float().detach()) + if self.offload_optimizer: + self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( + self.subgroup_to_device[i]).clone().float().detach()) + else: + self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( + self.device).clone().float().detach()) self.fp32_partitioned_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it @@ -886,10 +920,20 @@ def _release_ipg_buffers(self): def _optimizer_step(self, sub_group_id): param_group_id = self.sub_group_to_group_id[sub_group_id] fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] - - self.optimizer.step() - self.optimizer.param_groups[param_group_id]['params'] = [] + if self.offload_optimizer: + cur_device = self.subgroup_to_device[sub_group_id] + if cur_device == 'cpu': + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + cpu_loss = self.optimizer.step() + self.optimizer.param_groups[param_group_id]['params'] = [] + else: + self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param] + gpu_loss = self.backup_optimizer.step() + self.backup_optimizer.param_groups[param_group_id]['params'] = [] + else: + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + self.optimizer.step() + self.optimizer.param_groups[param_group_id]['params'] = [] def _swappable_optimizer_subgroup(self, sub_group_id): if not self.swap_optimizer: @@ -956,7 +1000,7 @@ def initialize_optimizer_states(self): if self.offload_optimizer_pin_memory: subgroup_gradient_buffer = get_accelerator().pin_memory(subgroup_gradient_buffer) - self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer + self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer.to(self.subgroup_to_device[i]) else: self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index cdd18e62a29e..e9d7166b05b3 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -576,6 +576,7 @@ Note that if the value of "device" is not specified or not supported, an asserti "device": "[cpu|nvme]", "nvme_path": "/local_nvme", "pin_memory": [true|false], + "ratio": 0.3, "buffer_count": 4, "fast_init": false } @@ -598,6 +599,12 @@ Note that if the value of "device" is not specified or not supported, an asserti | ---------------------------------------------------------------------------------------------------- | ------- | | Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead. 
| `false` | +***ratio***: [float] + +| Description | Default | +| ------------------------------------------------------------------- | ------- | +| the ratio of parameters updating (i.e. optimizer step) on CPU side. | 1 | + ***buffer_count***: [integer] | Description | Default | diff --git a/tests/small_model_debugging/partial_offload_test.py b/tests/small_model_debugging/partial_offload_test.py new file mode 100644 index 000000000000..2094448d534d --- /dev/null +++ b/tests/small_model_debugging/partial_offload_test.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import json +import argparse +import torch +import deepspeed +from torch.utils.data.distributed import DistributedSampler +import deepspeed.comm as dist + + +class SimpleModel(torch.nn.Module): + + def __init__(self, hidden_dim, empty_grad=False): + super(SimpleModel, self).__init__() + self.linear = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear4 = torch.nn.Linear(hidden_dim, hidden_dim) + if empty_grad: + self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)]) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + hidden = x + hidden = self.linear(hidden) + hidden = self.linear2(hidden) + hidden = self.linear3(hidden) + hidden = self.linear4(hidden) + return self.cross_entropy_loss(hidden, y) + + +def create_config_from_dict(tmpdir, config_dict): + config_path = os.path.join(tmpdir, 'temp_config.json') + with open(config_path, 'w') as fd: + json.dump(config_dict, fd) + return config_path + + +def get_data_loader(model, total_samples, hidden_dim, device): + batch_size = model.train_micro_batch_size_per_gpu() + train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half) + train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) + train_dataset = torch.utils.data.TensorDataset(train_data, train_label) + sampler = DistributedSampler(train_dataset) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler) + return train_loader + + +def get_args(tmpdir, config_dict): + parser = argparse.ArgumentParser() + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument('--zero', type=int, default=0) + args = parser.parse_args() #args='' + + config_dict["zero_optimization"]["stage"] = args.zero + print('config_dict["zero_optimization"]', config_dict["zero_optimization"]) + config_path = create_config_from_dict(tmpdir, config_dict) + + args.deepspeed_config = config_path + return args + + +def print0(msg): + if dist.get_rank() == 0: + print(msg, flush=True) + + +rank = int(os.environ['RANK']) +print('seed:', 2222 + rank) +torch.random.manual_seed(2222 + rank) + +config_dict = { + "train_batch_size": 256, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 15 + }, + "zero_optimization": { + "stage": 0, + "sub_group_size": 8, + "reduce_bucket_size": 20, + "offload_optimizer": { + "device": "cpu", + "pin_memory": True, + "ratio": 0.3 + } + } +} +# "initial_scale_power": 15 +args = get_args('/tmp/', config_dict) +hidden_dim = 4 * 1024 + +model = SimpleModel(hidden_dim, empty_grad=False) + +model, _, _, _ = deepspeed.initialize(args=args, + model=model, + 
model_parameters=model.parameters(), + dist_init_required=True) + + +def print_params(tag, model): + if dist.get_rank() == 0: + for n, p in model.named_parameters(): + print0("{} {}:{}".format(tag, n, p)) + + +data_loader = get_data_loader(model=model, total_samples=1000, hidden_dim=hidden_dim, device=model.device) +#print_params('pre-train', model) +#while True: +for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + if dist.get_rank() == 0: + print("LOSS:", loss.item()) + model.backward(loss) + model.step() + #print_params('step={}'.format(n), model) + if n == 2: break diff --git a/tests/unit/ops/adam/test_hybrid_adam.py b/tests/unit/ops/adam/test_hybrid_adam.py new file mode 100644 index 000000000000..c7ef4890b322 --- /dev/null +++ b/tests/unit/ops/adam/test_hybrid_adam.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import numpy as np +import pytest + +from cpuinfo import get_cpu_info + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam +from deepspeed.ops.op_builder import CPUAdamBuilder +from unit.common import DistributedTest + +if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("hybrid-adam is not compatible", allow_module_level=True) + +pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower() + + +def check_equal(first, second, atol=1e-2, verbose=False): + x = first.detach().numpy() + y = second.detach().numpy() + print("ATOL", atol) + if verbose: + print("x = {}".format(x.flatten())) + print("y = {}".format(y.flatten())) + print('-' * 80) + np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) + + +@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) +@pytest.mark.parametrize('model_size', [8, 16]) +class TestHybridAdam(DistributedTest): + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") + def test_hybrid_adam_equal(self, dtype, model_size): + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + ref_data = torch.randn(model_size).to(dtype) + total_data = ref_data.clone().detach() + + ref_param = torch.nn.Parameter(ref_data) + ref_optimizer = DeepSpeedCPUAdam([ref_param]) + + cpu_data, cuda_data = total_data.chunk(2) + cpu_param = torch.nn.Parameter(cpu_data) + cuda_param = torch.nn.Parameter(cuda_data.to(get_accelerator().device_name())) + + cpu_optimizer = DeepSpeedCPUAdam([cpu_param]) + cuda_optimizer = FusedAdam([cuda_param]) + + ref_grad = torch.randn(model_size).to(dtype) + cpu_grad, cuda_grad = ref_grad.clone().detach().chunk(2) + + ref_param.grad = ref_grad + cpu_param.grad = cpu_grad + cuda_param.grad = cuda_grad.to(get_accelerator().device_name()) + + ref_optimizer.step() + cpu_optimizer.step() + cuda_optimizer.step() + + cuda_param_copy = cuda_param.cpu() + + total_param = torch.cat((cpu_param, cuda_param_copy)) + + check_equal(ref_param, total_param) diff --git a/tests/unit/runtime/zero/test_zero_offloadpp.py b/tests/unit/runtime/zero/test_zero_offloadpp.py new file mode 100644 index 000000000000..c376686f8052 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_offloadpp.py @@ -0,0 +1,75 @@ +# Copyright 
(c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import pytest +import deepspeed.comm as dist +from unit.common import DistributedTest +from unit.simple_model import random_dataloader + +import deepspeed + +from deepspeed.runtime.zero.offload_config import DeepSpeedZeroOffloadOptimizerConfig + +import torch.nn as nn + + +class NNModel(nn.Module): + + def __init__(self, h_dim=1024, n_layers=2): + super(NNModel, self).__init__() + self.layers = nn.ModuleList([nn.Linear(h_dim, h_dim) for i in range(n_layers)]) + self.cross_entropy_loss = nn.CrossEntropyLoss() + + def forward(self, x, y): + for layer in self.layers: + x = layer(x) + return self.cross_entropy_loss(x, y) + + +def test_zero_partial_offload_config(): + config = DeepSpeedZeroOffloadOptimizerConfig(**{"ratio": 0.3}) + assert config.ratio == 0.3 + + +#Large sweep along hidden dim, num_layers of different sizes +@pytest.mark.parametrize("h_dim", [1024]) +@pytest.mark.parametrize("n_layers", [4, 8]) +class TestZeroPartialOffloadConfigSweep(DistributedTest): + world_size = 4 + + def test(self, h_dim: int, n_layers: int) -> None: + config_dict = { + "train_batch_size": 256, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 15 + }, + "zero_optimization": { + "stage": 3, + "sub_group_size": 8, + "reduce_bucket_size": 20, + "offload_optimizer": { + "device": "cpu", + "pin_memory": True, + "ratio": 0.3 + } + } + } + + model = NNModel(h_dim, n_layers) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=h_dim, device=model.device) + dist.barrier() + + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step()
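
A minimal usage sketch for the new partial-offload option, assuming the `ratio` field lands under `offload_optimizer` exactly as in the diff above. The config mirrors the one added in `tests/unit/runtime/zero/test_zero_offloadpp.py`; the model, batch size, and learning rate are placeholders, and partial offload (`ratio` < 1.0) is only valid with ZeRO stage 3:

```python
# Illustrative only: run under a DeepSpeed launcher (e.g. `deepspeed script.py`).
import deepspeed
import torch.nn as nn

ds_config = {
    "train_batch_size": 256,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,                      # partial offload requires ZeRO stage 3
        "sub_group_size": 8,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True,
            "ratio": 0.3,                # ~30% of optimizer subgroups are updated by CPU Adam
        },
    },
}

model = nn.Linear(1024, 1024)            # placeholder model for illustration
engine, _, _, _ = deepspeed.initialize(model=model,
                                        model_parameters=model.parameters(),
                                        config=ds_config)
```

With `ratio: 1.0` (the default) every subgroup is updated on the CPU, which reproduces the previous full-offload behaviour.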
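The placement rule added to `_create_fp32_partitions` is straightforward: with `n` flat FP16 subgroups and offload ratio `r`, the first `int(r * n)` subgroups keep their FP32 master weights (and gradients) on the CPU, and the remaining subgroups stay on the accelerator. A standalone sketch of that rule; `assign_subgroups_to_devices` is a hypothetical helper name used only for illustration:

```python
# Sketch of the placement rule added in _create_fp32_partitions (hypothetical helper name).
def assign_subgroups_to_devices(num_subgroups: int, offload_ratio: float,
                                accelerator_name: str = "cuda") -> dict:
    """Map subgroup index -> device, CPU-first up to the requested ratio."""
    cpu_count = int(offload_ratio * num_subgroups)
    return {i: ("cpu" if i < cpu_count else accelerator_name) for i in range(num_subgroups)}

# Example: 10 subgroups with ratio 0.3 -> subgroups 0-2 on CPU, 3-9 on the accelerator.
print(assign_subgroups_to_devices(10, 0.3))
```

In the diff the resulting mapping is stored as `self.subgroup_to_device` and reused both when the FP32 partitions are created and when the gradient buffers are placed in `initialize_optimizer_states`.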
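`_optimizer_step` then dispatches each subgroup to the optimizer that owns its device: CPU-resident subgroups are stepped by the original `DeepSpeedCPUAdam`, accelerator-resident subgroups by the backup `FusedAdam` built from the same hyperparameters. Below is a rough, self-contained sketch of that dispatch, using `torch.optim.Adam` as a stand-in for both optimizers so it runs without DeepSpeed's compiled ops; all names here are illustrative:

```python
# Stand-in sketch of the per-subgroup dispatch in _optimizer_step (not the real optimizers).
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
subgroup_to_device = {0: "cpu", 1: "cpu", 2: device}    # e.g. ratio 0.7 over 3 subgroups
fp32_subgroups = {i: torch.nn.Parameter(torch.randn(8, device=d))
                  for i, d in subgroup_to_device.items()}
for p in fp32_subgroups.values():
    p.grad = torch.zeros_like(p)                        # pretend gradients already exist

# Stand-ins for DeepSpeedCPUAdam and the backup FusedAdam, seeded with a dummy parameter.
cpu_optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
backup_optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1, device=device))], lr=1e-3)

def step_subgroup(sub_group_id: int) -> None:
    """Step one flat subgroup on whichever optimizer owns its device."""
    param = fp32_subgroups[sub_group_id]
    opt = cpu_optimizer if subgroup_to_device[sub_group_id] == "cpu" else backup_optimizer
    opt.param_groups[0]["params"] = [param]   # temporarily hand the flat FP32 partition over
    opt.step()
    opt.param_groups[0]["params"] = []        # release it again, as stage3.py does

for i in subgroup_to_device:
    step_subgroup(i)
```

The equivalence of splitting one parameter's update between `DeepSpeedCPUAdam` and `FusedAdam` is what the new `tests/unit/ops/adam/test_hybrid_adam.py` verifies.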