Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
tjruwase authored Oct 16, 2023
2 parents 63a5205 + 4fc181b commit 0912101
Show file tree
Hide file tree
Showing 39 changed files with 493 additions and 78 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/amd-mi200.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm5.6
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand Down
50 changes: 50 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
name: Build and publish DeepSpeed release

on:
push:
tags:
- 'v*.*.*'

jobs:
deploy:
runs-on: ubuntu-20.04
environment: release-env

steps:
- uses: actions/checkout@v3
with:
ref: "master"
- id: setup-venv
uses: ./.github/workflows/setup-venv
- name: Get release version from tag
run: |
echo "RELEASE_VERSION=${GITHUB_REF#refs/*/v}" >> $GITHUB_ENV
- name: Check release version
run: |
pip install packaging
python release/check_release_version.py --release_version ${{ env.RELEASE_VERSION }}
- name: Build DeepSpeed
run: |
DS_BUILD_STRING=" " python setup.py sdist
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
repository-url: https://upload.pypi.org/legacy/
- name: Bump version
run: |
python release/bump_patch_version.py --current_version ${{ env.RELEASE_VERSION }}
- name: Create Pull Request
uses: peter-evans/create-pull-request@v4
with:
token: ${{ secrets.GH_PAT }}
add-paths: |
version.txt
body: |
**Auto-generated PR to update version.txt after a DeepSpeed release**
Released version - ${{ env.RELEASE_VERSION }}
Author - @${{ github.actor }}
branch: AutoPR/${{ env.RELEASE_VERSION }}
assignees: ${{ github.actor }}
title: "Update version.txt after ${{ env.RELEASE_VERSION }} release"
author: ${{ github.actor }} <${{ github.actor }}@users.noreply.github.com>
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>

* [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)]
* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses) [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/japanese/README.md)]
Expand Down Expand Up @@ -236,6 +236,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
25. Zhewei Yao, Reza Yazdani Aminabadi, Olatunji Ruwase, Samyam Rajbhandari, Xiaoxia Wu, Ammar Ahmad Awan, Jeff Rasley, Minjia Zhang, Conglong Li, Connor Holmes, Zhongzhu Zhou, Michael Wyatt, Molly Smith, Lev Kurilenko, Heyang Qin, Masahiro Tanaka, Shuai Che, Shuaiwen Leon Song, Yuxiong He. (2023) DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales [arXiv:2308.01320](https://arxiv.org/abs/2308.01320).
26. Xiaoxia Wu, Zhewei Yao, Yuxiong He. (2023) ZeroQuant-FP: A Leap Forward in LLMs Post-Training W4A8 Quantization Using Floating-Point Formats [arXiv:2307.09782](https://arxiv.org/abs/2307.09782)
27. Zhewei Yao, Xiaoxia Wu, Conglong Li, Minjia Zhang, Heyang Qin, Olatunji Ruwase, Ammar Ahmad Awan, Samyam Rajbhandari, Yuxiong He. (2023) DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention [arXiv:2309.14327](https://arxiv.org/pdf/2309.14327.pdf)
28. Shuaiwen Leon Song, Bonnie Kruft, Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Masahiro Tanaka, Xiaoxia Wu, Jeff Rasley, Ammar Ahmad Awan, Connor Holmes, Martin Cai, Adam Ghanem, Zhongzhu Zhou, Yuxiong He, et al. (2023) DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies [arXiv:2310.04610](https://arxiv.org/abs/2310.04610)



Expand Down
11 changes: 11 additions & 0 deletions blogs/deepspeed4science/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,14 @@
</div>

[https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)

To cite DeepSpeed4Science, please cite our [white paper](https://arxiv.org/abs/2310.04610):

```
@article{song2023deepspeed4science,
title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
journal={arXiv preprint arXiv:2310.04610},
year={2023}
}
```
11 changes: 11 additions & 0 deletions blogs/deepspeed4science/chinese/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@
*图1:DeepSpeed4Science方法概述:专为加速科学发现和应对其复杂性而量身定制的AI系统技术开发。*
</div>

如需引用 DeepSpeed4Science,请引用我们的[white paper](https://arxiv.org/abs/2310.04610):

```
@article{song2023deepspeed4science,
title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
journal={arXiv preprint arXiv:2310.04610},
year={2023}
}
```

## 简介

在接下来的十年中,深度学习可能会彻底改变自然科学,增强我们对自然现象进行建模和预测的能力。这可能预示着科学探索的新时代,为从药物开发到可再生能源的各个领域带来重大进展。为了响应这一机会以及微软“予力全球每一人、每一组织,成就不凡”的使命,[微软DeepSpeed团队](https://www.deepspeed.ai/)启动了一个名为[DeepSpeed4Science](https://deepspeed4science.ai/)的新计划,旨在通过AI系统技术创新帮助领域专家解锁当今最大的科学之谜。
Expand Down
11 changes: 11 additions & 0 deletions blogs/deepspeed4science/japanese/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@
*図1:DeepSpeed4Scienceのアプローチ: 汎用の言語モデルのサポートを超え、科学的発見とその複雑さの解決に特化したAI技術を開発*
</div>

DeepSpeed4Science を引用するには、こちらの[white paper](https://arxiv.org/abs/2310.04610)を引用してください:

```
@article{song2023deepspeed4science,
title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
journal={arXiv preprint arXiv:2310.04610},
year={2023}
}
```

## はじめに

自然の出来事をモデル化し予測する深層学習の能力は急速に高まっており、次の10年間に、自然科学に革命を起こすかも知れません。薬の開発から再生可能エネルギーまでの各セクターで、大きな進展をもたらす新しい科学的探求の時代が到来するでしょう。「地球上のすべての人と組織がもっと多くのことを成し遂げられるようにする」というMicrosoftのミッションに従い、この機会に、[DeepSpeedチーム](https://www.deepspeed.ai/)では[DeepSpeed4Science](https://deepspeed4science.ai/)という新しいイニシアティブを立ち上げました。これは、AIシステム技術のイノベーションを通じて他に類を見ない技術を構築し、様々な分野の専門家が、科学分野における大きな謎を解き明かす手助けをすることを目指しています。
Expand Down
59 changes: 48 additions & 11 deletions csrc/cpu/comm/ccl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,20 @@ int world_size = -1;

std::set<int> _comm_ids;
std::set<int> _colors;
ccl::vector_class<ccl::communicator> _ccl_comms;
std::vector<ccl::communicator> _ccl_comms;
ccl::shared_ptr_class<ccl::kvs> sub_kvs;
std::map<std::vector<int>, int> group_to_comm_id;

ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; }
ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; }
ccl::communicator& _get_comm_from_group(std::vector<int> ranks)
{
if (group_to_comm_id.find(ranks) != group_to_comm_id.end()) {
auto id = group_to_comm_id.find(ranks);
return _ccl_comms[id->second];
}
return _ccl_comms[0];
}

#define CCLCHECK(cmd) \
do { \
Expand Down Expand Up @@ -394,12 +404,29 @@ int next_unique_val(std::set<int> s)
}
}

py::object new_group(std::vector<int> ranks)
std::vector<uint8_t> get_sub_kvs_addr(bool first)
{
if (first) {
sub_kvs = ccl::create_main_kvs();
ccl::kvs::address_type main_addr = sub_kvs->get_address();
auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
return ccl_kvs_addr;
} else {
ccl::kvs::address_type main_addr;
auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
return ccl_kvs_addr;
}
}

void initialize_sub_comm(int size, int rank, torch::Tensor& kvs_data, std::vector<int> ranks)
{
int comm_id = next_unique_val(_comm_ids);
int color = next_unique_val(_colors);
std::cout << "RANK: " << get_rank() << " COMM_ID: " << comm_id << " COLOR: " << color
<< std::endl;
ccl::kvs::address_type main_addr;
if (rank != 0) {
memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size());
sub_kvs = ccl::create_kvs(main_addr);
}
_ccl_comms.push_back(ccl::create_communicator(size, rank, sub_kvs));
group_to_comm_id[ranks] = _ccl_comms.size() - 1;
}

ccl::datatype get_ccl_datatype(c10::ScalarType type)
Expand Down Expand Up @@ -452,7 +479,7 @@ ccl::reduction get_ccl_reduce_op(py::object op, at::Tensor& input)
return ccl_op;
}

void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
void broadcast(torch::Tensor& data, int src, std::vector<int> group, bool async_op)
{
CCLCHECK(ccl::broadcast(data.data_ptr(),
data.numel(),
Expand All @@ -463,7 +490,7 @@ void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
}

// TODO: implement torch's async_op behavior, document it.
void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
void all_reduce(torch::Tensor& data, py::object op, std::vector<int> group, bool async_op)
{
CCLCHECK(ccl::allreduce(data.data_ptr(),
data.data_ptr(),
Expand All @@ -477,7 +504,7 @@ void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async
void all_reduce_caching(torch::Tensor& data,
py::object op,
std::string match_id,
py::object group,
std::vector<int> group,
bool async_op)
{
ccl::allreduce_attr attr = ccl::default_allreduce_attr;
Expand Down Expand Up @@ -510,7 +537,7 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes)
}
}

void inference_all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
void inference_all_reduce(torch::Tensor& data, py::object op, std::vector<int> group, bool async_op)
{
static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
Expand Down Expand Up @@ -583,11 +610,18 @@ void inference_all_reduce(torch::Tensor& data, py::object op, py::object group,
}
}

void barrier(py::object group, bool async_op)
void barrier(std::vector<int> group, bool async_op)
{
CCLCHECK(ccl::barrier(_get_comm_from_group(group)).wait());
}

std::vector<std::string> get_available_coll()
{
std::vector<std::string> colls{
"broadcast", "all_reduce", "inference_all_reduce", "all_reduce_caching", "barrier"};
return colls;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("get_kvs_addr", &get_kvs_addr, "create and get main kvs addr");
Expand All @@ -599,4 +633,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
m.def("all_reduce_caching", &all_reduce_caching, "ccl all_reduce with caching");
m.def("barrier", &barrier, "barrier");
m.def("initialize_sub_comm", &initialize_sub_comm, "initialize_sub_comm");
m.def("get_sub_kvs_addr", &get_sub_kvs_addr, "get_sub_kvs_addr");
m.def("get_available_coll", &get_available_coll, "get_available_coll");
}
8 changes: 8 additions & 0 deletions csrc/includes/quantization.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,11 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
int num_group,
int group_size,
cudaStream_t stream);

void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
half* data_out,
half* scale_buffer,
half* min_val_buffer,
int num_group,
int group_size,
cudaStream_t stream);
23 changes: 23 additions & 0 deletions csrc/quantization/pt_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,26 @@ at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in,
return output;
}

at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
at::Tensor& scale_buffer,
at::Tensor& min_val_buffer,
int num_group,
int group_size)
{
auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto output = torch::empty({num_group, group_size}, output_options);

launch_dequantize_int8_to_half_experimental((uint8_t*)data_in.data_ptr(),
(half*)output.data_ptr(),
(half*)scale_buffer.data_ptr(),
(half*)min_val_buffer.data_ptr(),
num_group,
group_size,
at::cuda::getCurrentCUDAStream());

return output;
}

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
Expand Down Expand Up @@ -270,6 +290,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("dequantize_int4_to_half_experimental",
&dequantize_int4_to_half_experimental,
"Dequantize int4 to half (experimental)");
m.def("dequantize_int8_to_half_experimental",
&dequantize_int8_to_half_experimental,
"Dequantize int8 to half (experimental)");
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
}
Loading

0 comments on commit 0912101

Please sign in to comment.