Merge branch 'master' into fix-rope-theta
cupertank authored Oct 12, 2023
2 parents f1b119e + 574fbc0 commit a8e2933
Showing 15 changed files with 180 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/amd-mi200.yml
@@ -28,7 +28,7 @@ jobs:

- name: Install pytorch
run: |
-pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
+pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm5.6
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
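As an aside on the index-URL change above, one way to confirm that a ROCm build of PyTorch was actually installed (rather than a CUDA wheel) is to check `torch.version.hip`. A minimal sketch, not part of this workflow:

```python
# Sanity-check sketch: on a ROCm wheel, torch.version.hip is a version
# string and torch.version.cuda is None; torch.cuda.is_available() still
# returns True because the HIP backend is exposed through the torch.cuda API.
import torch

print("torch:", torch.__version__)
print("HIP runtime:", torch.version.hip)
assert torch.version.hip is not None, "expected a ROCm build of torch"
```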
3 changes: 2 additions & 1 deletion README.md
@@ -16,7 +16,7 @@
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>

* [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)]
-* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
+* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses) [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/japanese/README.md)]
@@ -236,6 +236,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
25. Zhewei Yao, Reza Yazdani Aminabadi, Olatunji Ruwase, Samyam Rajbhandari, Xiaoxia Wu, Ammar Ahmad Awan, Jeff Rasley, Minjia Zhang, Conglong Li, Connor Holmes, Zhongzhu Zhou, Michael Wyatt, Molly Smith, Lev Kurilenko, Heyang Qin, Masahiro Tanaka, Shuai Che, Shuaiwen Leon Song, Yuxiong He. (2023) DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales [arXiv:2308.01320](https://arxiv.org/abs/2308.01320).
26. Xiaoxia Wu, Zhewei Yao, Yuxiong He. (2023) ZeroQuant-FP: A Leap Forward in LLMs Post-Training W4A8 Quantization Using Floating-Point Formats [arXiv:2307.09782](https://arxiv.org/abs/2307.09782)
27. Zhewei Yao, Xiaoxia Wu, Conglong Li, Minjia Zhang, Heyang Qin, Olatunji Ruwase, Ammar Ahmad Awan, Samyam Rajbhandari, Yuxiong He. (2023) DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention [arXiv:2309.14327](https://arxiv.org/pdf/2309.14327.pdf)
+28. Shuaiwen Leon Song, Bonnie Kruft, Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Masahiro Tanaka, Xiaoxia Wu, Jeff Rasley, Ammar Ahmad Awan, Connor Holmes, Martin Cai, Adam Ghanem, Zhongzhu Zhou, Yuxiong He, et al. (2023) DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies [arXiv:2310.04610](https://arxiv.org/abs/2310.04610)



11 changes: 11 additions & 0 deletions blogs/deepspeed4science/README.md
@@ -5,3 +5,14 @@
</div>

[https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)

To cite DeepSpeed4Science, please cite our [white paper](https://arxiv.org/abs/2310.04610):

```
@article{song2023deepspeed4science,
title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
journal={arXiv preprint arXiv:2310.04610},
year={2023}
}
```
11 changes: 11 additions & 0 deletions blogs/deepspeed4science/chinese/README.md
@@ -12,6 +12,17 @@
*Figure 1: Overview of the DeepSpeed4Science approach: developing AI system technologies tailored to accelerate scientific discovery and address its complexity.*
</div>

To cite DeepSpeed4Science, please cite our [white paper](https://arxiv.org/abs/2310.04610):

```
@article{song2023deepspeed4science,
title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
journal={arXiv preprint arXiv:2310.04610},
year={2023}
}
```

## Introduction

In the coming decade, deep learning may revolutionize the natural sciences, enhancing our ability to model and predict natural phenomena. This could herald a new era of scientific exploration, bringing major advances in fields ranging from drug development to renewable energy. In response to this opportunity, and to Microsoft's mission to empower every person and every organization on the planet to achieve more, the [Microsoft DeepSpeed team](https://www.deepspeed.ai/) has launched a new initiative called [DeepSpeed4Science](https://deepspeed4science.ai/), which aims to use AI system technology innovations to help domain experts unlock today's biggest scientific mysteries.
11 changes: 11 additions & 0 deletions blogs/deepspeed4science/japanese/README.md
@@ -12,6 +12,17 @@
*Figure 1: The DeepSpeed4Science approach: going beyond support for generic language models to develop AI technologies specialized for scientific discovery and its complexity.*
</div>

To cite DeepSpeed4Science, please cite this [white paper](https://arxiv.org/abs/2310.04610):

```
@article{song2023deepspeed4science,
title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
journal={arXiv preprint arXiv:2310.04610},
year={2023}
}
```

## Introduction

Deep learning's ability to model and predict natural phenomena is growing rapidly, and over the next decade it may revolutionize the natural sciences. A new era of scientific exploration will likely arrive, bringing major advances in sectors from drug development to renewable energy. In line with Microsoft's mission to "empower every person and every organization on the planet to achieve more", the [DeepSpeed team](https://www.deepspeed.ai/) has seized this opportunity and launched a new initiative called [DeepSpeed4Science](https://deepspeed4science.ai/). It aims to build unique capabilities through AI system technology innovations, helping experts across fields unlock the biggest mysteries in science.
8 changes: 8 additions & 0 deletions csrc/includes/quantization.h
@@ -98,3 +98,11 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
int num_group,
int group_size,
cudaStream_t stream);

void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size,
cudaStream_t stream);
23 changes: 23 additions & 0 deletions csrc/quantization/pt_binding.cpp
@@ -156,6 +156,26 @@ at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in,
return output;
}

at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
at::Tensor& scale_buffer,
at::Tensor& min_val_buffer,
int num_group,
int group_size)
{
auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto output = torch::empty({num_group, group_size}, output_options);

launch_dequantize_int8_to_half_experimental((uint8_t*)data_in.data_ptr(),
(half*)output.data_ptr(),
(half*)scale_buffer.data_ptr(),
(half*)min_val_buffer.data_ptr(),
num_group,
group_size,
at::cuda::getCurrentCUDAStream());

return output;
}

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
@@ -270,6 +290,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("dequantize_int4_to_half_experimental",
&dequantize_int4_to_half_experimental,
"Dequantize int4 to half (experimental)");
m.def("dequantize_int8_to_half_experimental",
&dequantize_int8_to_half_experimental,
"Dequantize int8 to half (experimental)");
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
}
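Once the extension is compiled, the new binding is callable from Python. Below is a hedged sketch of a direct call; `get_quantizer_cuda_module` is the loader used in `deepspeed/inference/quantization/utils.py`, and the shapes and values here are illustrative assumptions rather than an official usage pattern:

```python
# Sketch: exercising the new int8 dequantize binding directly.
# Assumes the DeepSpeed quantizer CUDA extension builds on this machine.
import torch
from deepspeed.inference.quantization.utils import get_quantizer_cuda_module

num_group, group_size = 32, 64
codes = torch.randint(0, 256, (num_group, group_size), dtype=torch.uint8, device='cuda')
scale = torch.rand(num_group, dtype=torch.float16, device='cuda') + 1  # per-group scales
min_val = torch.rand(num_group, dtype=torch.float16, device='cuda')    # per-group minima

quantizer = get_quantizer_cuda_module()
out = quantizer.dequantize_int8_to_half_experimental(codes, scale, min_val, num_group, group_size)
print(out.shape, out.dtype)  # torch.Size([32, 64]) torch.float16
```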
@@ -228,3 +228,54 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
dequantize_int4_to_half<<<num_block, 256, 0, stream>>>(
data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
}

template <int N>
__device__ __forceinline__ AlignedArray<half, N> int8_to_half(const AlignedArray<uint8_t, N>& data)
{
AlignedArray<half, N> ret;

#pragma unroll
for (int idx = 0; idx < N; idx += 1) { ret[idx] = half(int(data[idx])); }

return ret;
}

__global__ void dequantize_int8_to_half(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size)
{
using AccessType = AlignedArray<uint8_t, 8>;
using AccessTypeOut = AlignedArray<half, 8>;

for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < num_group * group_size / 8;
idx += blockDim.x * gridDim.x) {
int id_group = idx / (group_size / 8);
AccessType value = reinterpret_cast<AccessType*>(data_in)[idx];
half scale = scale_buffer[id_group];
half min_value = min_val_buffer[id_group];

AccessTypeOut output = int8_to_half(value);
output = divide<half, 8>()(output, scale);
output = plus<half, 8>()(output, min_value);

reinterpret_cast<AccessTypeOut*>(data_out)[idx] = output;
}
}

void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size,
cudaStream_t stream)
{
int num_warp = num_group / 4;
int num_block = num_warp / 8; // 256 threads per block

dequantize_int8_to_half<<<num_block, 256, 0, stream>>>(
data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
}
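The kernel's arithmetic is simple: each uint8 code is widened to half, divided by its group's scale, and offset by the group minimum, eight elements per vectorized access. Because the body is a grid-stride loop, the `num_group / 32` blocks of 256 threads computed above affect only occupancy, not correctness. Here is a reference sketch of the same math in Python, useful for checking the kernel's output (an assumption-level reference, not DeepSpeed code):

```python
# Reference sketch: out[g, i] = codes[g, i] / scale[g] + min_val[g],
# i.e. the per-group dequantization the CUDA kernel performs in fp16.
import numpy as np

def dequantize_int8_reference(codes, scale, min_val):
    # codes: (num_group, group_size) uint8; scale, min_val: (num_group,) float16
    return codes.astype(np.float16) / scale[:, None] + min_val[:, None]

codes = np.random.randint(0, 256, size=(4, 64), dtype=np.uint8)
scale = np.random.rand(4).astype(np.float16) + np.float16(1)
min_val = np.random.rand(4).astype(np.float16)
print(dequantize_int8_reference(codes, scale, min_val).dtype)  # float16
```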
3 changes: 2 additions & 1 deletion deepspeed/constants.py
@@ -3,6 +3,7 @@

# DeepSpeed Team

+import os
from datetime import timedelta

#############################################
@@ -15,6 +16,6 @@
# (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1).
# To make an attempt at backwards compatibility with THD, we use an
# extraordinarily high default timeout, given that THD did not have timeouts.
-default_pg_timeout = timedelta(minutes=30)
+default_pg_timeout = timedelta(minutes=int(os.getenv("DEEPSPEED_TIMEOUT", default=30)))
INFERENCE_GENERIC_MODE = 'generic'
INFERENCE_SPECIALIZED_MODE = 'specialized'
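This change makes the collective-communication timeout tunable at launch time instead of hard-coded at 30 minutes. A hedged sketch of raising it (the variable must be set before DeepSpeed is first imported, because the constant is evaluated at import time; the value is read as minutes):

```python
# Sketch: raising the default process-group timeout via the new env var.
# DEEPSPEED_TIMEOUT must be set before the first deepspeed import, since
# default_pg_timeout is computed once at module import.
import os
os.environ["DEEPSPEED_TIMEOUT"] = "60"  # minutes

from deepspeed.constants import default_pg_timeout  # noqa: E402
print(default_pg_timeout)  # 1:00:00
```

Exporting the variable in the environment of the launching shell should achieve the same effect for local runs.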
18 changes: 11 additions & 7 deletions deepspeed/inference/quantization/utils.py
@@ -105,20 +105,24 @@ def __init__(self, config: Dict, dtype: torch.dtype) -> None:
def dequantize(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) -> Tensor:
    # Use customized CUDA quantization kernel if possible.
    if self.config['group_size'] % 8 == 0 and \
-           self.config['num_bits'] == 4 and \
+           (self.config['num_bits'] == 4 or self.config['num_bits'] == 8) and \
            self.config['group_dim'] == len(tensor.shape) - 1 and \
            self.dtype == torch.float16 and device == 'cuda':

        last_dimension_size = self.config['group_size']
        if self.config['num_bits'] == 4:
            last_dimension_size = last_dimension_size // 2
-       quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
-           tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
-           tensor.numel() // last_dimension_size, self.config['group_size'])
-
-       shape = list(tensor.shape)
-       shape[-1] = shape[-1] * 2
+           quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
+               tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
+               tensor.numel() // last_dimension_size, self.config['group_size'])
+           shape = list(tensor.shape)
+           shape[-1] = shape[-1] * 2
+       elif self.config['num_bits'] == 8:
+           # last_dimension_size = last_dimension_size // 2
+           quantized_tensor = get_quantizer_cuda_module().dequantize_int8_to_half_experimental(
+               tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
+               tensor.numel() // last_dimension_size, self.config['group_size'])
+           shape = list(tensor.shape)

        return quantized_tensor.reshape(shape)

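For orientation, here is a config that would satisfy the guard above and take the new int8 fast path; the key names come from this code, and the complete schema should be checked against DeepSpeed's weight-quantization docs:

```python
# Sketch: settings under which dequantize() dispatches to the int8 CUDA kernel.
config = {
    'num_bits': 8,     # routes to dequantize_int8_to_half_experimental
    'group_size': 64,  # must be divisible by 8
    'group_dim': 1,    # last dimension of a 2-D weight tensor
}
# The guard additionally requires dtype == torch.float16 and a CUDA device.
```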
4 changes: 2 additions & 2 deletions deepspeed/profiling/flops_profiler/profiler.py
@@ -722,9 +722,9 @@ def _upsample_flops_compute(*args, **kwargs):

flops = input.numel()
if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
-flops * int(_prod(scale_factor))
+flops *= int(_prod(scale_factor))
else:
-flops * scale_factor**len(input)
+flops *= scale_factor**len(input)
return flops, 0


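The fix above matters because a bare `flops * x` computes a product and discards it, so the profiler was returning only `input.numel()` for upsample ops; the augmented assignment actually folds in the scale factor. A two-line illustration of the difference:

```python
flops = 100
flops * 4           # no-op: the result is computed and thrown away
assert flops == 100
flops *= 4          # augmented assignment rebinds flops
assert flops == 400
```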
13 changes: 12 additions & 1 deletion docs/_pages/deepspeed4science.md
@@ -6,7 +6,18 @@ toc_label: "Contents"
toc_sticky: true
---

-In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts unlock today's biggest science mysteries. This page serves as an overview of all technologies released (or to be released in the future) as part of the DeepSpeed4Science initiative, making it easier for scientists to shop for techniques they need. Details of the DeepSpeed4Science initiative can be found at [our website](https://deepspeed4science.ai/). For each technique we will introduce what it is for, when to use it, links to how to use it, and existing scientific applications of the techniques (we welcome users to contribute more showcases if you apply our techniques in your scientific research):
+In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts unlock today's biggest science mysteries. This page serves as an overview of all technologies released (or to be released in the future) as part of the DeepSpeed4Science initiative, making it easier for scientists to shop for techniques they need. Details of the DeepSpeed4Science initiative can be found at [our website](https://deepspeed4science.ai/). For each technique we will introduce what it is for, when to use it, links to how to use it, and existing scientific applications of the techniques (we welcome users to contribute more showcases if you apply our techniques in your scientific research).

To cite DeepSpeed4Science, please cite our [white paper](https://arxiv.org/abs/2310.04610):

```
@article{song2023deepspeed4science,
title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
journal={arXiv preprint arXiv:2310.04610},
year={2023}
}
```

* [2023/09] We are releasing two techniques: [DeepSpeed4Science large-scale training framework](#new-megatron-deepspeed-for-large-scale-ai4science-model-training), [DS4Sci_EvoformerAttention](#memory-efficient-evoformerattention-kernels) and their scientific applications in structural biology research.

