Merge branch 'master' into fix_attention_mask
hxdtest authored Oct 11, 2023
2 parents 1a59765 + 6c86ff3 commit 3c7dc49
Showing 27 changed files with 224 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/amd-mi200.yml
@@ -28,7 +28,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm5.6
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
5 changes: 4 additions & 1 deletion .github/workflows/release.yml
@@ -14,15 +14,18 @@ jobs:
- uses: actions/checkout@v3
with:
ref: "master"
- id: setup-venv
uses: ./.github/workflows/setup-venv
- name: Get release version from tag
run: |
echo "RELEASE_VERSION=${GITHUB_REF#refs/*/v}" >> $GITHUB_ENV
- name: Check release version
run: |
pip install packaging
python release/check_release_version.py --release_version ${{ env.RELEASE_VERSION }}
- name: Build DeepSpeed
run: |
DS_BUILD_STRING=" " python setup.py sdist_wheel
DS_BUILD_STRING=" " python setup.py sdist
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
8 changes: 8 additions & 0 deletions csrc/includes/quantization.h
@@ -98,3 +98,11 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
int num_group,
int group_size,
cudaStream_t stream);

void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size,
cudaStream_t stream);
23 changes: 23 additions & 0 deletions csrc/quantization/pt_binding.cpp
@@ -156,6 +156,26 @@ at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in,
return output;
}

at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
at::Tensor& scale_buffer,
at::Tensor& min_val_buffer,
int num_group,
int group_size)
{
auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto output = torch::empty({num_group, group_size}, output_options);

launch_dequantize_int8_to_half_experimental((uint8_t*)data_in.data_ptr(),
(half*)output.data_ptr(),
(half*)scale_buffer.data_ptr(),
(half*)min_val_buffer.data_ptr(),
num_group,
group_size,
at::cuda::getCurrentCUDAStream());

return output;
}

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
@@ -270,6 +290,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("dequantize_int4_to_half_experimental",
&dequantize_int4_to_half_experimental,
"Dequantize int4 to half (experimental)");
m.def("dequantize_int8_to_half_experimental",
&dequantize_int8_to_half_experimental,
"Dequantize int8 to half (experimental)");
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
}
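
A hedged usage sketch of the new binding (illustrative shapes chosen to satisfy the kernel layout — one fp16 scale and min value per group; get_quantizer_cuda_module is the helper used in deepspeed/inference/quantization/utils.py further down in this diff):

# Illustrative only: calling the new int8 dequantize op directly.
import torch
from deepspeed.inference.quantization.utils import get_quantizer_cuda_module

num_group, group_size = 1024, 64
data = torch.randint(0, 256, (num_group, group_size), dtype=torch.uint8, device='cuda')
scales = torch.rand(num_group, 1, dtype=torch.float16, device='cuda')
mins = torch.rand(num_group, 1, dtype=torch.float16, device='cuda')

out = get_quantizer_cuda_module().dequantize_int8_to_half_experimental(
    data, scales, mins, num_group, group_size)   # -> (num_group, group_size), fp16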
@@ -228,3 +228,54 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
dequantize_int4_to_half<<<num_block, 256, 0, stream>>>(
data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
}

template <int N>
__device__ __forceinline__ AlignedArray<half, N> int8_to_half(const AlignedArray<uint8_t, N>& data)
{
AlignedArray<half, N> ret;

#pragma unroll
for (int idx = 0; idx < N; idx += 1) { ret[idx] = half(int(data[idx])); }

return ret;
}

__global__ void dequantize_int8_to_half(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size)
{
using AccessType = AlignedArray<uint8_t, 8>;
using AccessTypeOut = AlignedArray<half, 8>;

for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < num_group * group_size / 8;
idx += blockDim.x * gridDim.x) {
int id_group = idx / (group_size / 8);
AccessType value = reinterpret_cast<AccessType*>(data_in)[idx];
half scale = scale_buffer[id_group];
half min_value = min_val_buffer[id_group];

AccessTypeOut output = int8_to_half(value);
output = divide<half, 8>()(output, scale);
output = plus<half, 8>()(output, min_value);

reinterpret_cast<AccessTypeOut*>(data_out)[idx] = output;
}
}

void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size,
cudaStream_t stream)
{
int num_warp = num_group / 4;
int num_block = num_warp / 8; // 256 trd / block

dequantize_int8_to_half<<<num_block, 256, 0, stream>>>(
data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
}
2 changes: 1 addition & 1 deletion deepspeed/comm/config.py
@@ -3,8 +3,8 @@

# DeepSpeed Team

from pydantic import BaseModel
from .constants import *
from ..pydantic_v1 import BaseModel


class CommsConfig(BaseModel):
3 changes: 2 additions & 1 deletion deepspeed/constants.py
@@ -3,6 +3,7 @@

# DeepSpeed Team

import os
from datetime import timedelta

#############################################
@@ -15,6 +16,6 @@
# (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1).
# To make an attempt at backwards compatibility with THD, we use an
# extraordinarily high default timeout, given that THD did not have timeouts.
default_pg_timeout = timedelta(minutes=30)
default_pg_timeout = timedelta(minutes=int(os.getenv("DEEPSPEED_TIMEOUT", default=30)))
INFERENCE_GENERIC_MODE = 'generic'
INFERENCE_SPECIALIZED_MODE = 'specialized'
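
A minimal sketch of the new override (illustrative; the variable must be set before deepspeed is first imported, since the value is read when deepspeed/constants.py loads):

# Illustrative only: raise the default process-group timeout to 90 minutes.
import os
os.environ["DEEPSPEED_TIMEOUT"] = "90"   # minutes, parsed with int()

import deepspeed  # default_pg_timeout now reflects the override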
3 changes: 1 addition & 2 deletions deepspeed/inference/config.py
@@ -5,10 +5,9 @@

import torch
import deepspeed
from deepspeed.pydantic_v1 import Field, validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from pydantic import Field
from pydantic import validator
from typing import Dict, Union
from enum import Enum

18 changes: 11 additions & 7 deletions deepspeed/inference/quantization/utils.py
@@ -105,20 +105,24 @@ def __init__(self, config: Dict, dtype: torch.dtype) -> None:
def dequantize(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) -> Tensor:
# Use customized CUDA quantization kernel if possible.
if self.config['group_size'] % 8 == 0 and \
self.config['num_bits'] == 4 and \
(self.config['num_bits'] == 4 or self.config['num_bits'] == 8) and \
self.config['group_dim'] == len(tensor.shape) - 1 and \
self.dtype == torch.float16 and device == 'cuda':

last_dimension_size = self.config['group_size']
if self.config['num_bits'] == 4:
last_dimension_size = last_dimension_size // 2
quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
tensor.numel() // last_dimension_size, self.config['group_size'])

shape = list(tensor.shape)
if self.config['num_bits'] == 4:
quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
tensor.numel() // last_dimension_size, self.config['group_size'])
shape = list(tensor.shape)
shape[-1] = shape[-1] * 2
elif self.config['num_bits'] == 8:
# last_dimension_size = last_dimension_size // 2
quantized_tensor = get_quantizer_cuda_module().dequantize_int8_to_half_experimental(
tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
tensor.numel() // last_dimension_size, self.config['group_size'])
shape = list(tensor.shape)

return quantized_tensor.reshape(shape)
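
For reference, a hypothetical config that would now take the int8 CUDA fast path (all values are assumptions chosen to satisfy the conditions above: group_size divisible by 8, num_bits == 8, grouping along the last dim, fp16 data on a CUDA device):

# Hypothetical quantization config hitting dequantize_int8_to_half_experimental.
config = {
    'num_bits': 8,
    'group_size': 64,
    'group_dim': 1,   # last dim of a 2-D weight tensor
}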

20 changes: 19 additions & 1 deletion deepspeed/module_inject/auto_tp.py
@@ -11,7 +11,7 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw

@@ -318,6 +318,11 @@ def _replace(self, child, name, conv_linear_layer):
del data

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
child.bias.to(get_accelerator().current_device_name())), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group)
else:
@@ -436,3 +441,16 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
self.update_mp_params(child)
self._replace_module(child, name, class_name)
return r_module

def _replace_last_linear_module(self, r_module):
if hasattr(r_module, "lm_head"):
name = "lm_head"
child = r_module.lm_head
elif hasattr(r_module, "embed_out"):
name = "embed_out"
child = r_module.embed_out
else:
return r_module
if child.__class__ in self.linear_policies:
setattr(r_module, name, self.linear_policies[child.__class__](child, name, self.conv_linear_layer))
return r_module
30 changes: 30 additions & 0 deletions deepspeed/module_inject/layers.py
@@ -29,6 +29,36 @@ def forward(self, input):
return output


class LmHeadLinearAllreduce(nn.Module):

def __init__(
self,
weight,
rank,
world_size,
bias=None,
mp_group=None,
):
super(LmHeadLinearAllreduce, self).__init__()
self.weight = weight
self.bias = bias
self.mp_group = mp_group
self.rank = rank
self.world_size = world_size

def forward(self, input):
assert input.shape[
-1] % self.world_size == 0, 'Please ensure that input.shape[-1] is divisible by self.world_size'
input_shard = input.shape[-1] // self.world_size
output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard],
self.weight.transpose(-1, -2))
if self.mp_group is not None:
dist.inference_all_reduce(output, group=self.mp_group)
if self.bias is not None:
output += self.bias
return output
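
The forward above is the usual row-parallel pattern: each rank multiplies its slice of the input by its weight shard, and the all-reduce sums the partial products. A minimal single-process sketch of that equivalence (illustrative only; it slices a full weight to stand in for per-rank shards and skips the communicator):

import torch

world_size = 2
x = torch.randn(1, 4, 8)    # (batch, seq, hidden)
w = torch.randn(16, 8)      # lm_head weight: (vocab, hidden)

full = torch.matmul(x, w.transpose(-1, -2))

shard = x.shape[-1] // world_size
partials = [
    torch.matmul(x[:, :, r * shard:(r + 1) * shard],
                 w[:, r * shard:(r + 1) * shard].transpose(-1, -2))
    for r in range(world_size)
]
assert torch.allclose(full, sum(partials), atol=1e-5)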


class LinearLayer(nn.Module):

def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
11 changes: 10 additions & 1 deletion deepspeed/module_inject/replace_module.py
@@ -275,6 +275,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
_autotp.update_linear_policies()

# 4. Replace modules
if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
return _autotp._replace_last_linear_module(module)
return _autotp._replace_module(module)

def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
@@ -304,6 +306,13 @@ def set_lm_head(module):
if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
module.lm_head, "weight") and module.lm_head.weight.is_meta:
module.lm_head.weight = embedding_weight
# enable tensor parallel for the last linear
if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and not module.lm_head.weight.is_meta:
module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
elif hasattr(module, "embed_out") and hasattr(module.embed_out,
"weight") and not module.embed_out.weight.is_meta:
module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
return module

if checkpoint_dict is not None and not config.replace_with_kernel_inject:
# AutoTP shard loading
@@ -318,7 +327,7 @@ def set_lm_head(module):
checkpoint=checkpoint_file)
pbar.update(1)
gc.collect()
set_lm_head(replaced_module)
replaced_module = set_lm_head(replaced_module)
else:
replaced_module = replace_module(model=model,
orig_class=orig_layer_impl,
2 changes: 1 addition & 1 deletion deepspeed/monitor/config.py
@@ -3,7 +3,7 @@

# DeepSpeed Team

from pydantic import root_validator
from deepspeed.pydantic_v1 import root_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel


16 changes: 16 additions & 0 deletions deepspeed/pydantic_v1.py
@@ -0,0 +1,16 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""Pydantic v1 compatibility module.
Pydantic v2 introduced breaking changes that hinder its adoption:
https://docs.pydantic.dev/latest/migration/. To provide deepspeed users the option to
migrate to pydantic v2 on their own timeline, deepspeed uses this compatibility module
as a pydantic-version-agnostic alias for pydantic's v1 API.
"""

try:
from pydantic.v1 import * # noqa: F401
except ImportError:
from pydantic import * # noqa: F401
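
A quick sketch of how the shim is consumed elsewhere in this commit (the config class below is hypothetical; the import line mirrors e.g. deepspeed/comm/config.py):

# Works whether pydantic 1.x or 2.x is installed:
from deepspeed.pydantic_v1 import BaseModel, Field

class ExampleCommsConfig(BaseModel):   # hypothetical model for illustration
    enabled: bool = Field(True)
    prefix: str = Field("comms")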
2 changes: 2 additions & 0 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -270,6 +270,8 @@ def gather_partitioned_activations(tensors, device=None):
# don't need to do all_gather if model parallel is not enabled
if mp_group is None or mp_size == 1:
item = item.view(list(size.numpy()))
if device is not None:
item = item.to(device)
inputs.append(item)
continue

2 changes: 1 addition & 1 deletion deepspeed/runtime/config_utils.py
@@ -9,7 +9,7 @@
import collections
import collections.abc
from functools import reduce
from pydantic import BaseModel
from deepspeed.pydantic_v1 import BaseModel
from deepspeed.utils import logger


2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/config.py
@@ -3,10 +3,10 @@

# DeepSpeed Team

from pydantic import Field, validator
import sys
from typing import Optional
from enum import Enum
from deepspeed.pydantic_v1 import Field, validator
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
from deepspeed.utils import logger
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/offload_config.py
@@ -3,9 +3,9 @@

# DeepSpeed Team

from pydantic import Field, validator
from enum import Enum
from pathlib import Path
from deepspeed.pydantic_v1 import Field, validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int


2 changes: 1 addition & 1 deletion docs/_tutorials/advanced-install.md
@@ -61,6 +61,7 @@ Available `DS_BUILD` options include:
* `DS_BUILD_CCL_COMM` builds the communication collective libs
* `DS_BUILD_CPU_ADAM` builds the CPUAdam op
* `DS_BUILD_CPU_LION` builds the CPULion op
* `DS_BUILD_EVOFORMER_ATTN` builds the EvoformerAttn op (from [Alphafold](https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/))
* `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex))
* `DS_BUILD_FUSED_LION` builds the FusedLion op
* `DS_BUILD_CPU_ADAGRAD` builds the CPUAdagrad op
@@ -71,7 +72,6 @@ Available `DS_BUILD` options include:
* `DS_BUILD_TRANSFORMER` builds the transformer op
* `DS_BUILD_TRANSFORMER_INFERENCE` builds the transformer-inference op
* `DS_BUILD_STOCHASTIC_TRANSFORMER` builds the stochastic transformer op
* `DS_BUILD_UTILS` builds various optimized utilities

To speed up the build-all process, you can parallelize the compilation process with:

2 changes: 1 addition & 1 deletion environment.yml
@@ -18,4 +18,4 @@ dependencies:
- certifi
- openssl
- python=3.10
- pydantic<2.0.0
- pydantic