Merge branch 'master' into fix_visible_devices
loadams authored Oct 11, 2023
2 parents 063108a + 6c86ff3 commit 0d61d38
Showing 7 changed files with 127 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/amd-mi200.yml
@@ -28,7 +28,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm5.6
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
8 changes: 8 additions & 0 deletions csrc/includes/quantization.h
@@ -98,3 +98,11 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
int num_group,
int group_size,
cudaStream_t stream);

void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size,
cudaStream_t stream);
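
For reference, the new declaration mirrors the int4 variant directly above it: one fp16 scale and one fp16 minimum per group, applied in the kernel later in this diff as x = q / scale + min. A minimal NumPy sketch of that arithmetic (the function and argument names below are illustrative, not part of the commit):

import numpy as np

def dequantize_int8_reference(data_in, scale_buffer, min_val_buffer, num_group, group_size):
    # One uint8 value per element; one scale and one minimum per group.
    q = data_in.reshape(num_group, group_size).astype(np.float16)
    scale = scale_buffer.reshape(num_group, 1).astype(np.float16)
    min_val = min_val_buffer.reshape(num_group, 1).astype(np.float16)
    return q / scale + min_val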
23 changes: 23 additions & 0 deletions csrc/quantization/pt_binding.cpp
@@ -156,6 +156,26 @@ at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in,
return output;
}

at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
at::Tensor& scale_buffer,
at::Tensor& min_val_buffer,
int num_group,
int group_size)
{
auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto output = torch::empty({num_group, group_size}, output_options);

launch_dequantize_int8_to_half_experimental((uint8_t*)data_in.data_ptr(),
(half*)output.data_ptr(),
(half*)scale_buffer.data_ptr(),
(half*)min_val_buffer.data_ptr(),
num_group,
group_size,
at::cuda::getCurrentCUDAStream());

return output;
}

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
@@ -270,6 +290,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("dequantize_int4_to_half_experimental",
&dequantize_int4_to_half_experimental,
"Dequantize int4 to half (experimental)");
m.def("dequantize_int8_to_half_experimental",
&dequantize_int8_to_half_experimental,
"Dequantize int8 to half (experimental)");
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
}
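
For illustration, once the quantizer extension is built, the new binding can be exercised directly from Python through get_quantizer_cuda_module() from deepspeed/inference/quantization/utils.py (modified later in this diff). The shapes and values below are made up for the example and assume a CUDA device is available:

import torch
from deepspeed.inference.quantization.utils import get_quantizer_cuda_module

num_group, group_size = 32, 64
data = torch.randint(0, 256, (num_group, group_size), dtype=torch.uint8, device='cuda')
scale = torch.rand(num_group, 1, dtype=torch.float16, device='cuda') + 0.5  # avoid near-zero scales
min_val = torch.rand(num_group, 1, dtype=torch.float16, device='cuda')

out = get_quantizer_cuda_module().dequantize_int8_to_half_experimental(
    data, scale, min_val, num_group, group_size)
print(out.shape, out.dtype)  # torch.Size([32, 64]) torch.float16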
@@ -228,3 +228,54 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
dequantize_int4_to_half<<<num_block, 256, 0, stream>>>(
data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
}

template <int N>
__device__ __forceinline__ AlignedArray<half, N> int8_to_half(const AlignedArray<uint8_t, N>& data)
{
AlignedArray<half, N> ret;

#pragma unroll
for (int idx = 0; idx < N; idx += 1) { ret[idx] = half(int(data[idx])); }

return ret;
}

__global__ void dequantize_int8_to_half(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size)
{
using AccessType = AlignedArray<uint8_t, 8>;
using AccessTypeOut = AlignedArray<half, 8>;

for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < num_group * group_size / 8;
idx += blockDim.x * gridDim.x) {
int id_group = idx / (group_size / 8);
AccessType value = reinterpret_cast<AccessType*>(data_in)[idx];
half scale = scale_buffer[id_group];
half min_value = min_val_buffer[id_group];

AccessTypeOut output = int8_to_half(value);
output = divide<half, 8>()(output, scale);
output = plus<half, 8>()(output, min_value);

reinterpret_cast<AccessTypeOut*>(data_out)[idx] = output;
}
}

void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size,
cudaStream_t stream)
{
int num_warp = num_group / 4;
int num_block = num_warp / 8; // 256 trd / block

dequantize_int8_to_half<<<num_block, 256, 0, stream>>>(
data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
}
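
The launcher sizes its grid the same way as the int4 launcher above it: 256-thread blocks (8 warps each), one warp requested per 4 groups, and each thread consuming 8 packed bytes per iteration of the grid-stride loop. A quick check of that arithmetic with illustrative sizes (not taken from the commit):

# Illustrative check of the launch arithmetic above.
num_group, group_size = 4096, 64          # example sizes; group_size must be a multiple of 8
num_warp = num_group // 4                 # 1024
num_block = num_warp // 8                 # 128 blocks of 256 threads each
threads = num_block * 256                 # 32768 threads
vec_loads = num_group * group_size // 8   # 32768 eight-byte vector loads in total
print(num_block, threads, vec_loads)      # each thread handles exactly one load in this case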
18 changes: 11 additions & 7 deletions deepspeed/inference/quantization/utils.py
@@ -105,20 +105,24 @@ def __init__(self, config: Dict, dtype: torch.dtype) -> None:
def dequantize(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) -> Tensor:
# Use customized CUDA quantization kernel if possible.
if self.config['group_size'] % 8 == 0 and \
self.config['num_bits'] == 4 and \
(self.config['num_bits'] == 4 or self.config['num_bits'] == 8) and \
self.config['group_dim'] == len(tensor.shape) - 1 and \
self.dtype == torch.float16 and device == 'cuda':

last_dimension_size = self.config['group_size']
if self.config['num_bits'] == 4:
last_dimension_size = last_dimension_size // 2
quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
tensor.numel() // last_dimension_size, self.config['group_size'])

shape = list(tensor.shape)
if self.config['num_bits'] == 4:
quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
tensor.numel() // last_dimension_size, self.config['group_size'])
shape = list(tensor.shape)
shape[-1] = shape[-1] * 2
elif self.config['num_bits'] == 8:
# last_dimension_size = last_dimension_size // 2
quantized_tensor = get_quantizer_cuda_module().dequantize_int8_to_half_experimental(
tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
tensor.numel() // last_dimension_size, self.config['group_size'])
shape = list(tensor.shape)

return quantized_tensor.reshape(shape)

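With this change the CUDA fast path covers both 4-bit and 8-bit weights, provided the group size is a multiple of 8, grouping is along the last dimension, and the tensor is fp16 on a CUDA device; anything else still takes the generic fallback path (not shown in this hunk). A hypothetical weight-quantization config that would route through the new int8 branch (the keys follow the test configs later in this diff; the values are illustrative):

# Hypothetical config exercising the new int8 kernel path (illustrative values only).
ds_weight_quant_config = {
    'weight_quantization': {
        'post_init_quant': {
            'layer': {
                'num_bits': 8,      # selects the new int8 branch
                'group_size': 64,   # divisible by 8
                'group_dim': 1,     # last dimension of a 2-D weight
                'symmetric': False
            }
        }
    }
}
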
2 changes: 1 addition & 1 deletion op_builder/quantizer.py
@@ -22,7 +22,7 @@ def sources(self):
'csrc/quantization/pt_binding.cpp',
'csrc/quantization/fake_quantizer.cu',
'csrc/quantization/quantize.cu',
'csrc/quantization/quantize_int4.cu',
'csrc/quantization/quantize_intX.cu',
'csrc/quantization/dequantize.cu',
'csrc/quantization/swizzled_quantize.cu',
'csrc/quantization/quant_reduce.cu',
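The builder change only swaps which .cu source is compiled into the quantizer extension; loading the op is unchanged. A minimal sketch of JIT-building and loading it (this assumes DeepSpeed's standard op-builder API and is not part of the commit):

from deepspeed.ops.op_builder import QuantizerBuilder

# Builds (or reuses) the quantizer extension, which now compiles quantize_intX.cu.
quantizer_module = QuantizerBuilder().load()
print(hasattr(quantizer_module, 'dequantize_int8_to_half_experimental'))
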
@@ -53,12 +53,11 @@ def quantization_test_helper(pre_quant_type: torch.dtype, num_bits: int):
assert mean_diff < 0.15 and max_diff < 0.5, f'Numeric error exceed threshold, mean diff {mean_diff} (threshold 0.15), max diff {max_diff} (threshold 0.5)'


def zero3_post_init_quantization_test_helper(cpu_offload: bool, nvme_offload: bool):
def zero3_post_init_quantization_test_helper(cpu_offload: bool, nvme_offload: bool, bits: int):
import deepspeed
from transformers.deepspeed import HfDeepSpeedConfig

def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool) -> Dict:
bits = 4
def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict:
GB = 1 << 30

ds_config = {
@@ -143,7 +142,7 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
return ds_config

hf_config = AutoConfig.from_pretrained('facebook/opt-125m')
ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload)
ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload, bits=bits)

input_ids = torch.ones(1, 16, dtype=torch.int32, device=device)
attention_mask = torch.ones(1, 16, dtype=torch.float32, device=device)
@@ -171,12 +170,11 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
assert mean_diff < 0.4, f'Numeric error exceed threshold, relative error {mean_diff} (threshold 0.4)'


def zero3_quantized_initialization_test_helper(cpu_offload: bool, nvme_offload: bool):
def zero3_quantized_initialization_test_helper(cpu_offload: bool, nvme_offload: bool, bits: int):
import deepspeed
from transformers.deepspeed import HfDeepSpeedConfig

def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool) -> Dict:
bits = 4
def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict:
GB = 1 << 30

ds_config = {
@@ -223,7 +221,7 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
return ds_config

hf_config = AutoConfig.from_pretrained('facebook/opt-125m')
ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload)
ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload, bits=bits)

input_ids = torch.ones(1, 16, dtype=torch.int32, device=device)
attention_mask = torch.ones(1, 16, dtype=torch.float32, device=device)
@@ -249,16 +247,26 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
assert mean_diff < 0.4, f'Numeric error exceed threshold, relative error {mean_diff} (threshold 0.4)'


class TestQuantizedInt4(DistributedTest):
@pytest.fixture(params=[4, 8], ids=["4bits", "8bits"])
def quantization_bits(request):
return request.param

def test_model_quantization(self):

@pytest.fixture(params=[0, 1], ids=["0", "1"])
def group_dim(request):
return request.param


class TestQuantizedInt(DistributedTest):

def test_model_quantization(self, quantization_bits):
reset_random()

config = AutoConfig.from_pretrained('facebook/opt-125m')

with torch.no_grad():
model = OPTDecoderLayer(config).half().to(device)
bits = 4
bits = quantization_bits

ds_config = {
'weight_quantization': {
@@ -307,7 +315,7 @@ def test_model_quantization(self):
assert type(model.self_attn.out_proj) is QuantizedLinear

@pytest.mark.skipif(device == 'cpu', reason='CPU does not support FP16 GEMM')
def test_quantized_linear(self):
def test_quantized_linear(self, quantization_bits, group_dim):
reset_random()

layers = []
@@ -326,9 +334,9 @@ def test_quantized_linear(self):
'weight_quantization': {
'post_init_quant': {
'layer': {
'num_bits': 4,
'num_bits': quantization_bits,
'group_size': 64,
'group_dim': 0,
'group_dim': group_dim,
'symmetric': False
}
}
@@ -368,31 +376,31 @@ def test_half_int8_quantization(self):
quantization_test_helper(torch.float16, 8)

@pytest.mark.skipif(device == 'cpu', reason='CPU does not support FP16 GEMM')
def test_zero3_int4_post_init_quant(self):
def test_zero3_int4_post_init_quant(self, quantization_bits):
reset_random()
zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=False)
zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=False, bits=quantization_bits)

@pytest.mark.skipif(device == 'cpu', reason='CPU does not support FP16 GEMM')
def test_zero3_int4_post_init_quant_cpu_offload(self):
def test_zero3_int4_post_init_quant_cpu_offload(self, quantization_bits):
reset_random()
zero3_post_init_quantization_test_helper(cpu_offload=True, nvme_offload=False)
zero3_post_init_quantization_test_helper(cpu_offload=True, nvme_offload=False, bits=quantization_bits)

@pytest.mark.skipif(device == 'cpu', reason='CPU does not support FP16 GEMM')
def test_zero3_int4_post_init_quant_nvme_offload(self):
reset_random()
zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=True)
zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=True, bits=4)

@pytest.mark.skipif(device == 'cpu', reason='CPU does not support FP16 GEMM')
def test_zero3_int4_quantized_initialization(self):
def test_zero3_int4_quantized_initialization(self, quantization_bits):
reset_random()
zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=False)
zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=False, bits=quantization_bits)

@pytest.mark.skipif(device == 'cpu', reason='CPU does not support FP16 GEMM')
def test_zero3_int4_quantized_initialization_cpu_offload(self):
def test_zero3_int4_quantized_initialization_cpu_offload(self, quantization_bits):
reset_random()
zero3_quantized_initialization_test_helper(cpu_offload=True, nvme_offload=False)
zero3_quantized_initialization_test_helper(cpu_offload=True, nvme_offload=False, bits=quantization_bits)

@pytest.mark.skipif(device == 'cpu', reason='CPU does not support FP16 GEMM')
def test_zero3_int4_quantized_initialization_nvme_offload(self):
reset_random()
zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=True)
zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=True, bits=4)
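
The net effect of the test changes is that the zero3 helpers now thread a bits argument down into the generated DeepSpeed config instead of hard-coding 4, so the same helpers cover the new 8-bit path. A trimmed, illustrative sketch of that wiring (the keys here are illustrative; the real get_zero3_ds_config in the diff also carries zero_optimization and offload settings):

# Trimmed sketch, not the full test helper; keys and values are illustrative.
def get_zero3_ds_config_sketch(cpu_offload: bool, nvme_offload: bool, bits: int) -> dict:
    return {
        'weight_quantization': {
            'quantized_initialization': {
                'num_bits': bits,   # forwarded from the quantization_bits fixture (4 or 8)
                'group_size': 64,
                'group_dim': 1,
                'symmetric': False
            }
        }
    }

assert get_zero3_ds_config_sketch(False, False, 8)['weight_quantization']['quantized_initialization']['num_bits'] == 8

Since quantization_bits is parametrized with [4, 8], each zero3 test that takes the fixture now runs once per bit width; only the nvme-offload variants stay pinned to bits=4.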
