diff --git a/.github/workflows/amd-mi200.yml b/.github/workflows/amd-mi200.yml
index a275225cc5e4..77f33f744ea8 100644
--- a/.github/workflows/amd-mi200.yml
+++ b/.github/workflows/amd-mi200.yml
@@ -28,7 +28,7 @@ jobs:
       - name: Install pytorch
         run: |
-          pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
+          pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm5.6
           python -c "import torch; print('torch:', torch.__version__, torch)"
           python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 093bc98d41e8..8e016b4169cb 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -14,15 +14,18 @@ jobs:
       - uses: actions/checkout@v3
         with:
           ref: "master"
+      - id: setup-venv
+        uses: ./.github/workflows/setup-venv
       - name: Get release version from tag
         run: |
           echo "RELEASE_VERSION=${GITHUB_REF#refs/*/v}" >> $GITHUB_ENV
       - name: Check release version
         run: |
+          pip install packaging
           python release/check_release_version.py --release_version ${{ env.RELEASE_VERSION }}
       - name: Build DeepSpeed
         run: |
-          DS_BUILD_STRING=" " python setup.py sdist_wheel
+          DS_BUILD_STRING=" " python setup.py sdist
       - name: Publish to PyPI
         uses: pypa/gh-action-pypi-publish@release/v1
         with:
diff --git a/csrc/includes/quantization.h b/csrc/includes/quantization.h
index d2873abf1839..45828832d8d2 100644
--- a/csrc/includes/quantization.h
+++ b/csrc/includes/quantization.h
@@ -98,3 +98,11 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
                                                  int num_group,
                                                  int group_size,
                                                  cudaStream_t stream);
+
+void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
+                                                 half* data_out,
+                                                 half* scale_buffer,
+                                                 half* min_val_buffer,
+                                                 int num_group,
+                                                 int group_size,
+                                                 cudaStream_t stream);
diff --git a/csrc/quantization/pt_binding.cpp b/csrc/quantization/pt_binding.cpp
index d4c253ee005d..a4210897092d 100644
--- a/csrc/quantization/pt_binding.cpp
+++ b/csrc/quantization/pt_binding.cpp
@@ -156,6 +156,26 @@ at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in,
     return output;
 }
 
+at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
+                                                at::Tensor& scale_buffer,
+                                                at::Tensor& min_val_buffer,
+                                                int num_group,
+                                                int group_size)
+{
+    auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+    auto output = torch::empty({num_group, group_size}, output_options);
+
+    launch_dequantize_int8_to_half_experimental((uint8_t*)data_in.data_ptr(),
+                                                (half*)output.data_ptr(),
+                                                (half*)scale_buffer.data_ptr(),
+                                                (half*)min_val_buffer.data_ptr(),
+                                                num_group,
+                                                group_size,
+                                                at::cuda::getCurrentCUDAStream());
+
+    return output;
+}
+
 std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
                                          int groups,
                                          int num_bits,
@@ -270,6 +290,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("dequantize_int4_to_half_experimental",
           &dequantize_int4_to_half_experimental,
           "Dequantize int4 to half (experimental)");
+    m.def("dequantize_int8_to_half_experimental",
+          &dequantize_int8_to_half_experimental,
+          "Dequantize int8 to half (experimental)");
     m.def("swizzle_quant", &ds_swizzle_quant);
     m.def("quantized_reduction", &quantized_reduction);
 }
diff --git a/csrc/quantization/quantize_int4.cu b/csrc/quantization/quantize_intX.cu
similarity index 76%
rename from csrc/quantization/quantize_int4.cu
rename to csrc/quantization/quantize_intX.cu
index fed707c1fa7c..b26151ab5c8c 100644
--- a/csrc/quantization/quantize_int4.cu
+++ b/csrc/quantization/quantize_intX.cu
@@ -228,3 +228,54 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
     dequantize_int4_to_half<<<num_block, 256, 0, stream>>>(
         data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
 }
+
+template <int N>
+__device__ __forceinline__ AlignedArray<half, N> int8_to_half(const AlignedArray<uint8_t, N>& data)
+{
+    AlignedArray<half, N> ret;
+
+#pragma unroll
+    for (int idx = 0; idx < N; idx += 1) { ret[idx] = half(int(data[idx])); }
+
+    return ret;
+}
+
+__global__ void dequantize_int8_to_half(uint8_t* data_in,
+                                        half* data_out,
+                                        half* scale_buffer,
+                                        half* min_val_buffer,
+                                        int num_group,
+                                        int group_size)
+{
+    using AccessType = AlignedArray<uint8_t, 8>;
+    using AccessTypeOut = AlignedArray<half, 8>;
+
+    for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < num_group * group_size / 8;
+         idx += blockDim.x * gridDim.x) {
+        int id_group = idx / (group_size / 8);
+        AccessType value = reinterpret_cast<AccessType*>(data_in)[idx];
+        half scale = scale_buffer[id_group];
+        half min_value = min_val_buffer[id_group];
+
+        AccessTypeOut output = int8_to_half(value);
+        output = divide()(output, scale);
+        output = plus()(output, min_value);
+
+        reinterpret_cast<AccessTypeOut*>(data_out)[idx] = output;
+    }
+}
+
+void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
+                                                 half* data_out,
+                                                 half* scale_buffer,
+                                                 half* min_val_buffer,
+                                                 int num_group,
+                                                 int group_size,
+                                                 cudaStream_t stream)
+{
+    int num_warp = num_group / 4;
+    int num_block = num_warp / 8;  // 256 threads per block
+
+    dequantize_int8_to_half<<<num_block, 256, 0, stream>>>(
+        data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
+}
diff --git a/deepspeed/comm/config.py b/deepspeed/comm/config.py
index 138badebe5a9..1c441bb6bfe9 100644
--- a/deepspeed/comm/config.py
+++ b/deepspeed/comm/config.py
@@ -3,8 +3,8 @@
 # DeepSpeed Team
 
-from pydantic import BaseModel
 from .constants import *
+from ..pydantic_v1 import BaseModel
 
 
 class CommsConfig(BaseModel):
diff --git a/deepspeed/constants.py b/deepspeed/constants.py
index 7ebc8f9983a5..30135f41b7b6 100644
--- a/deepspeed/constants.py
+++ b/deepspeed/constants.py
@@ -3,6 +3,7 @@
 # DeepSpeed Team
 
+import os
 from datetime import timedelta
 
 #############################################
@@ -15,6 +16,6 @@
 # (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1).
 # To make an attempt at backwards compatibility with THD, we use an
 # extraordinarily high default timeout, given that THD did not have timeouts.
-default_pg_timeout = timedelta(minutes=30)
+default_pg_timeout = timedelta(minutes=int(os.getenv("DEEPSPEED_TIMEOUT", default=30)))
 
 INFERENCE_GENERIC_MODE = 'generic'
 INFERENCE_SPECIALIZED_MODE = 'specialized'
diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index 70f1c0dbd5b7..1d5018aaa75b 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -5,10 +5,9 @@
 
 import torch
 import deepspeed
+from deepspeed.pydantic_v1 import Field, validator
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
-from pydantic import Field
-from pydantic import validator
 from typing import Dict, Union
 from enum import Enum
diff --git a/deepspeed/inference/quantization/utils.py b/deepspeed/inference/quantization/utils.py
index d47eb265c214..712abc384a44 100644
--- a/deepspeed/inference/quantization/utils.py
+++ b/deepspeed/inference/quantization/utils.py
@@ -105,20 +105,24 @@ def __init__(self, config: Dict, dtype: torch.dtype) -> None:
 
     def dequantize(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) -> Tensor:
         # Use customized CUDA quantization kernel if possible.
         if self.config['group_size'] % 8 == 0 and \
-            self.config['num_bits'] == 4 and \
+            (self.config['num_bits'] == 4 or self.config['num_bits'] == 8) and \
             self.config['group_dim'] == len(tensor.shape) - 1 and \
             self.dtype == torch.float16 and device == 'cuda':
 
             last_dimension_size = self.config['group_size']
             if self.config['num_bits'] == 4:
                 last_dimension_size = last_dimension_size // 2
-            quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
-                tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
-                tensor.numel() // last_dimension_size, self.config['group_size'])
-
-            shape = list(tensor.shape)
-            if self.config['num_bits'] == 4:
+                quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
+                    tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
+                    tensor.numel() // last_dimension_size, self.config['group_size'])
+                shape = list(tensor.shape)
                 shape[-1] = shape[-1] * 2
+            elif self.config['num_bits'] == 8:
+                # last_dimension_size = last_dimension_size // 2
+                quantized_tensor = get_quantizer_cuda_module().dequantize_int8_to_half_experimental(
+                    tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
+                    tensor.numel() // last_dimension_size, self.config['group_size'])
+                shape = list(tensor.shape)
 
             return quantized_tensor.reshape(shape)
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index daf143919558..2e348de63454 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -11,7 +11,7 @@ from typing import Optional
 import torch
 from deepspeed import comm as dist
-from .layers import LinearAllreduce, LinearLayer
+from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
 from deepspeed.accelerator import get_accelerator
 from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
@@ -318,6 +318,11 @@ def _replace(self, child, name, conv_linear_layer):
                 del data
 
             setattr(child, "replaced", True)
+            if name == "lm_head" or name == 'embed_out':
+                return LmHeadLinearAllreduce(
+                    torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
+                    child.bias if child.bias is None else torch.nn.parameter.Parameter(
+                        child.bias.to(get_accelerator().current_device_name())), self.mp_group)
             return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
                         torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group)
         else:
@@ -436,3 +441,16 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
             self.update_mp_params(child)
             self._replace_module(child, name, class_name)
         return r_module
+
+    def _replace_last_linear_module(self, r_module):
+        if hasattr(r_module, "lm_head"):
+            name = "lm_head"
+            child = r_module.lm_head
+        elif hasattr(r_module, "embed_out"):
+            name = "embed_out"
+            child = r_module.embed_out
+        else:
+            return r_module
+        if child.__class__ in self.linear_policies:
+            setattr(r_module, name, self.linear_policies[child.__class__](child, name, self.conv_linear_layer))
+        return r_module
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index aa29651ec4cf..7a565560dec9 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -29,6 +29,36 @@ def forward(self, input):
         return output
 
 
+class LmHeadLinearAllreduce(nn.Module):
+
+    def __init__(
+        self,
+        weight,
+        rank,
+        world_size,
+        bias=None,
+        mp_group=None,
+    ):
+        super(LmHeadLinearAllreduce, self).__init__()
+        self.weight = weight
+        self.bias = bias
+        self.mp_group = mp_group
+        self.rank = rank
+        self.world_size = world_size
+
+    def forward(self, input):
+        assert input.shape[-1] % self.world_size == 0, \
+            'Please ensure that input.shape[-1] is divisible by self.world_size'
+        input_shard = input.shape[-1] // self.world_size
+        output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard],
+                              self.weight.transpose(-1, -2))
+        if self.mp_group is not None:
+            dist.inference_all_reduce(output, group=self.mp_group)
+        if self.bias is not None:
+            output += self.bias
+        return output
+
+
 class LinearLayer(nn.Module):
 
     def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index f9dd921b9ae1..8666372fa3f4 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -275,6 +275,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         _autotp.update_linear_policies()
 
         # 4. Replace modules
+        if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
+            return _autotp._replace_last_linear_module(module)
         return _autotp._replace_module(module)
 
     def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
@@ -304,6 +306,13 @@ def set_lm_head(module):
         if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
                 module.lm_head, "weight") and module.lm_head.weight.is_meta:
             module.lm_head.weight = embedding_weight
+        # enable tensor parallel for the last linear
+        if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and not module.lm_head.weight.is_meta:
+            module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
+        elif hasattr(module, "embed_out") and hasattr(module.embed_out,
+                                                      "weight") and not module.embed_out.weight.is_meta:
+            module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
+        return module
 
     if checkpoint_dict is not None and not config.replace_with_kernel_inject:
         # AutoTP shard loading
@@ -318,7 +327,7 @@ def set_lm_head(module):
                                        checkpoint=checkpoint_file)
             pbar.update(1)
             gc.collect()
-        set_lm_head(replaced_module)
+        replaced_module = set_lm_head(replaced_module)
     else:
         replaced_module = replace_module(model=model,
                                          orig_class=orig_layer_impl,
diff --git a/deepspeed/monitor/config.py b/deepspeed/monitor/config.py
index 2706764290fd..5a8ca6ecf5cd 100644
--- a/deepspeed/monitor/config.py
+++ b/deepspeed/monitor/config.py
@@ -3,7 +3,7 @@
 # DeepSpeed Team
 
-from pydantic import root_validator
+from deepspeed.pydantic_v1 import root_validator
 
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
diff --git a/deepspeed/pydantic_v1.py b/deepspeed/pydantic_v1.py
new file mode 100644
index 000000000000..6aba072ad929
--- /dev/null
+++ b/deepspeed/pydantic_v1.py
@@ -0,0 +1,16 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""Pydantic v1 compatibility module.
+
+Pydantic v2 introduced breaking changes that hinder its adoption:
+https://docs.pydantic.dev/latest/migration/. To provide deepspeed users the option to
+migrate to pydantic v2 on their own timeline, deepspeed uses this compatibility module
+as a pydantic-version-agnostic alias for pydantic's v1 API.
+""" + +try: + from pydantic.v1 import * # noqa: F401 +except ImportError: + from pydantic import * # noqa: F401 diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 77407a52026a..108cb37b57fb 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -270,6 +270,8 @@ def gather_partitioned_activations(tensors, device=None): # don't need to do all_gather if model parallel is not enabled if mp_group is None or mp_size == 1: item = item.view(list(size.numpy())) + if device is not None: + item = item.to(device) inputs.append(item) continue diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py index 0fb1372deac8..5522a8e79d69 100755 --- a/deepspeed/runtime/config_utils.py +++ b/deepspeed/runtime/config_utils.py @@ -9,7 +9,7 @@ import collections import collections.abc from functools import reduce -from pydantic import BaseModel +from deepspeed.pydantic_v1 import BaseModel from deepspeed.utils import logger diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 1fc11f0e46f5..35d60b5b3290 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -3,10 +3,10 @@ # DeepSpeed Team -from pydantic import Field, validator import sys from typing import Optional from enum import Enum +from deepspeed.pydantic_v1 import Field, validator from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel from deepspeed.utils import logger from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index c3a6dc7af530..1bd79412d39f 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -3,9 +3,9 @@ # DeepSpeed Team -from pydantic import Field, validator from enum import Enum from pathlib import Path +from deepspeed.pydantic_v1 import Field, validator from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int diff --git a/docs/_tutorials/advanced-install.md b/docs/_tutorials/advanced-install.md index c2b4c04cad1c..10197e62f681 100755 --- a/docs/_tutorials/advanced-install.md +++ b/docs/_tutorials/advanced-install.md @@ -61,6 +61,7 @@ Available `DS_BUILD` options include: * `DS_BUILD_CCL_COMM` builds the communication collective libs * `DS_BUILD_CPU_ADAM` builds the CPUAdam op * `DS_BUILD_CPU_LION` builds the CPULion op +* `DS_BUILD_EVOFORMER_ATTN` builds the EvoformerAttn op (from [Alphafold](https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/)) * `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex)) * `DS_BUILD_FUSED_LION` builds the FusedLion op * `DS_BUILD_CPU_ADAGRAD` builds the CPUAdagrad op @@ -71,7 +72,6 @@ Available `DS_BUILD` options include: * `DS_BUILD_TRANSFORMER` builds the transformer op * `DS_BUILD_TRANSFORMER_INFERENCE` builds the transformer-inference op * `DS_BUILD_STOCHASTIC_TRANSFORMER` builds the stochastic transformer op -* `DS_BUILD_UTILS` builds various optimized utilities To speed up the build-all process, you can parallelize the compilation process with: diff --git a/environment.yml b/environment.yml index e55fe96e5a5a..28c298717d80 100644 --- a/environment.yml +++ b/environment.yml @@ -18,4 +18,4 @@ dependencies: - certifi - openssl - python=3.10 
- - pydantic<2.0.0 + - pydantic diff --git a/op_builder/async_io.py b/op_builder/async_io.py index da511a0a8c9d..b55c821910b9 100644 --- a/op_builder/async_io.py +++ b/op_builder/async_io.py @@ -5,7 +5,6 @@ import distutils.spawn import subprocess -import torch from .builder import OpBuilder @@ -36,6 +35,7 @@ def cxx_args(self): # -O0 for improved debugging, since performance is bound by I/O CPU_ARCH = self.cpu_arch() SIMD_WIDTH = self.simd_width() + import torch # Keep this import here to avoid errors when building DeepSpeed wheel without torch installed TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[0:2]) if TORCH_MAJOR >= 2 and TORCH_MINOR >= 1: CPP_STD = '-std=c++17' diff --git a/op_builder/quantizer.py b/op_builder/quantizer.py index ada80b8f3331..fd765b743de0 100644 --- a/op_builder/quantizer.py +++ b/op_builder/quantizer.py @@ -22,7 +22,7 @@ def sources(self): 'csrc/quantization/pt_binding.cpp', 'csrc/quantization/fake_quantizer.cu', 'csrc/quantization/quantize.cu', - 'csrc/quantization/quantize_int4.cu', + 'csrc/quantization/quantize_intX.cu', 'csrc/quantization/dequantize.cu', 'csrc/quantization/swizzled_quantize.cu', 'csrc/quantization/quant_reduce.cu', diff --git a/requirements/requirements-readthedocs.txt b/requirements/requirements-readthedocs.txt index a6d7915e0ea5..fcd0ec5a9a6a 100644 --- a/requirements/requirements-readthedocs.txt +++ b/requirements/requirements-readthedocs.txt @@ -1,9 +1,9 @@ -autodoc_pydantic<2.0.0 +autodoc_pydantic docutils<0.18 hjson packaging psutil py-cpuinfo -pydantic<2.0.0 +pydantic torch tqdm diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 8c5e76750573..6840d6dbcc98 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,6 +4,6 @@ numpy packaging>=20.0 psutil py-cpuinfo -pydantic<2.0.0 +pydantic torch tqdm diff --git a/tests/unit/inference/quantization/test_int4_quantization.py b/tests/unit/inference/quantization/test_intX_quantization.py similarity index 91% rename from tests/unit/inference/quantization/test_int4_quantization.py rename to tests/unit/inference/quantization/test_intX_quantization.py index 56a5a7d48382..56df2b232d15 100644 --- a/tests/unit/inference/quantization/test_int4_quantization.py +++ b/tests/unit/inference/quantization/test_intX_quantization.py @@ -53,12 +53,11 @@ def quantization_test_helper(pre_quant_type: torch.dtype, num_bits: int): assert mean_diff < 0.15 and max_diff < 0.5, f'Numeric error exceed threshold, mean diff {mean_diff} (threshold 0.15), max diff {max_diff} (threshold 0.5)' -def zero3_post_init_quantization_test_helper(cpu_offload: bool, nvme_offload: bool): +def zero3_post_init_quantization_test_helper(cpu_offload: bool, nvme_offload: bool, bits: int): import deepspeed from transformers.deepspeed import HfDeepSpeedConfig - def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool) -> Dict: - bits = 4 + def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict: GB = 1 << 30 ds_config = { @@ -143,7 +142,7 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b return ds_config hf_config = AutoConfig.from_pretrained('facebook/opt-125m') - ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload) + ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload, bits=bits) input_ids = torch.ones(1, 16, dtype=torch.int32, device=device) attention_mask = 
     attention_mask = torch.ones(1, 16, dtype=torch.float32, device=device)
@@ -171,12 +170,11 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
     assert mean_diff < 0.4, f'Numeric error exceed threshold, relative error {mean_diff} (threshold 0.4)'
 
 
-def zero3_quantized_initialization_test_helper(cpu_offload: bool, nvme_offload: bool):
+def zero3_quantized_initialization_test_helper(cpu_offload: bool, nvme_offload: bool, bits: int):
     import deepspeed
     from transformers.deepspeed import HfDeepSpeedConfig
 
-    def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool) -> Dict:
-        bits = 4
+    def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict:
         GB = 1 << 30
 
         ds_config = {
@@ -223,7 +221,7 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
         return ds_config
 
     hf_config = AutoConfig.from_pretrained('facebook/opt-125m')
-    ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload)
+    ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload, bits=bits)
 
     input_ids = torch.ones(1, 16, dtype=torch.int32, device=device)
     attention_mask = torch.ones(1, 16, dtype=torch.float32, device=device)
@@ -249,16 +247,26 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
     assert mean_diff < 0.4, f'Numeric error exceed threshold, relative error {mean_diff} (threshold 0.4)'
 
 
-class TestQuantizedInt4(DistributedTest):
+@pytest.fixture(params=[4, 8], ids=["4bits", "8bits"])
+def quantization_bits(request):
+    return request.param
 
-    def test_model_quantization(self):
+
+@pytest.fixture(params=[0, 1], ids=["0", "1"])
+def group_dim(request):
+    return request.param
+
+
+class TestQuantizedInt(DistributedTest):
+
+    def test_model_quantization(self, quantization_bits):
         reset_random()
         config = AutoConfig.from_pretrained('facebook/opt-125m')
 
         with torch.no_grad():
             model = OPTDecoderLayer(config).half().to(device)
-            bits = 4
+            bits = quantization_bits
 
             ds_config = {
                 'weight_quantization': {
@@ -307,7 +315,7 @@ def test_model_quantization(self):
         assert type(model.self_attn.out_proj) is QuantizedLinear
 
     @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
-    def test_quantized_linear(self):
+    def test_quantized_linear(self, quantization_bits, group_dim):
         reset_random()
 
         layers = []
@@ -326,9 +334,9 @@ def test_quantized_linear(self):
             'weight_quantization': {
                 'post_init_quant': {
                     'layer': {
-                        'num_bits': 4,
+                        'num_bits': quantization_bits,
                         'group_size': 64,
-                        'group_dim': 0,
+                        'group_dim': group_dim,
                         'symmetric': False
                     }
                 }
             }
@@ -368,31 +376,31 @@ def test_half_int8_quantization(self):
         quantization_test_helper(torch.float16, 8)
 
     @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
-    def test_zero3_int4_post_init_quant(self):
+    def test_zero3_int4_post_init_quant(self, quantization_bits):
         reset_random()
-        zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=False)
+        zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=False, bits=quantization_bits)
 
     @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
-    def test_zero3_int4_post_init_quant_cpu_offload(self):
+    def test_zero3_int4_post_init_quant_cpu_offload(self, quantization_bits):
         reset_random()
-        zero3_post_init_quantization_test_helper(cpu_offload=True, nvme_offload=False)
+        zero3_post_init_quantization_test_helper(cpu_offload=True, nvme_offload=False, bits=quantization_bits)
 
     @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
     def test_zero3_int4_post_init_quant_nvme_offload(self):
         reset_random()
-        zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=True)
+        zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=True, bits=4)
 
     @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
-    def test_zero3_int4_quantized_initialization(self):
+    def test_zero3_int4_quantized_initialization(self, quantization_bits):
         reset_random()
-        zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=False)
+        zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=False, bits=quantization_bits)
 
     @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
-    def test_zero3_int4_quantized_initialization_cpu_offload(self):
+    def test_zero3_int4_quantized_initialization_cpu_offload(self, quantization_bits):
         reset_random()
-        zero3_quantized_initialization_test_helper(cpu_offload=True, nvme_offload=False)
+        zero3_quantized_initialization_test_helper(cpu_offload=True, nvme_offload=False, bits=quantization_bits)
 
     @pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
     def test_zero3_int4_quantized_initialization_nvme_offload(self):
         reset_random()
-        zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=True)
+        zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=True, bits=4)
diff --git a/tests/unit/runtime/test_ds_config_model.py b/tests/unit/runtime/test_ds_config_model.py
index b9c67c9a30dd..87ea747cf423 100644
--- a/tests/unit/runtime/test_ds_config_model.py
+++ b/tests/unit/runtime/test_ds_config_model.py
@@ -6,8 +6,8 @@
 import pytest
 import os
 import json
-from pydantic import Field, ValidationError
 from typing import List
+from deepspeed.pydantic_v1 import Field, ValidationError
 from deepspeed.runtime import config as ds_config
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
diff --git a/version.txt b/version.txt
index af88ba824866..bc859cbd6d99 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.11.1
+0.11.2