diff --git a/.github/workflows/amd-mi200.yml b/.github/workflows/amd-mi200.yml
index a275225cc5e4..77f33f744ea8 100644
--- a/.github/workflows/amd-mi200.yml
+++ b/.github/workflows/amd-mi200.yml
@@ -28,7 +28,7 @@ jobs:
- name: Install pytorch
run: |
- pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
+ pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm5.6
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 000000000000..8e016b4169cb
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,50 @@
+name: Build and publish DeepSpeed release
+
+on:
+ push:
+ tags:
+ - 'v*.*.*'
+
+jobs:
+ deploy:
+ runs-on: ubuntu-20.04
+ environment: release-env
+
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ ref: "master"
+ - id: setup-venv
+ uses: ./.github/workflows/setup-venv
+ - name: Get release version from tag
+ run: |
+ echo "RELEASE_VERSION=${GITHUB_REF#refs/*/v}" >> $GITHUB_ENV
+ - name: Check release version
+ run: |
+ pip install packaging
+ python release/check_release_version.py --release_version ${{ env.RELEASE_VERSION }}
+ - name: Build DeepSpeed
+ run: |
+ DS_BUILD_STRING=" " python setup.py sdist
+ - name: Publish to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
+ with:
+ password: ${{ secrets.PYPI_API_TOKEN }}
+ repository-url: https://upload.pypi.org/legacy/
+ - name: Bump version
+ run: |
+ python release/bump_patch_version.py --current_version ${{ env.RELEASE_VERSION }}
+ - name: Create Pull Request
+ uses: peter-evans/create-pull-request@v4
+ with:
+ token: ${{ secrets.GH_PAT }}
+ add-paths: |
+ version.txt
+ body: |
+ **Auto-generated PR to update version.txt after a DeepSpeed release**
+ Released version - ${{ env.RELEASE_VERSION }}
+ Author - @${{ github.actor }}
+ branch: AutoPR/${{ env.RELEASE_VERSION }}
+ assignees: ${{ github.actor }}
+ title: "Update version.txt after ${{ env.RELEASE_VERSION }} release"
+ author: ${{ github.actor }} <${{ github.actor }}@users.noreply.github.com>
diff --git a/README.md b/README.md
index abf20667882b..81ac3031457a 100755
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat).
* [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)]
-* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
+* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses) [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/japanese/README.md)]
@@ -236,6 +236,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
25. Zhewei Yao, Reza Yazdani Aminabadi, Olatunji Ruwase, Samyam Rajbhandari, Xiaoxia Wu, Ammar Ahmad Awan, Jeff Rasley, Minjia Zhang, Conglong Li, Connor Holmes, Zhongzhu Zhou, Michael Wyatt, Molly Smith, Lev Kurilenko, Heyang Qin, Masahiro Tanaka, Shuai Che, Shuaiwen Leon Song, Yuxiong He. (2023) DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales [arXiv:2308.01320](https://arxiv.org/abs/2308.01320).
26. Xiaoxia Wu, Zhewei Yao, Yuxiong He. (2023) ZeroQuant-FP: A Leap Forward in LLMs Post-Training W4A8 Quantization Using Floating-Point Formats [arXiv:2307.09782](https://arxiv.org/abs/2307.09782)
27. Zhewei Yao, Xiaoxia Wu, Conglong Li, Minjia Zhang, Heyang Qin, Olatunji Ruwase, Ammar Ahmad Awan, Samyam Rajbhandari, Yuxiong He. (2023) DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention [arXiv:2309.14327](https://arxiv.org/pdf/2309.14327.pdf)
+28. Shuaiwen Leon Song, Bonnie Kruft, Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Masahiro Tanaka, Xiaoxia Wu, Jeff Rasley, Ammar Ahmad Awan, Connor Holmes, Martin Cai, Adam Ghanem, Zhongzhu Zhou, Yuxiong He, et al. (2023) DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies [arXiv:2310.04610](https://arxiv.org/abs/2310.04610)
diff --git a/blogs/deepspeed4science/README.md b/blogs/deepspeed4science/README.md
index 2a80ea2e749e..a318490329a5 100644
--- a/blogs/deepspeed4science/README.md
+++ b/blogs/deepspeed4science/README.md
@@ -5,3 +5,14 @@
[https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)
+
+To cite DeepSpeed4Science, please cite our [white paper](https://arxiv.org/abs/2310.04610):
+
+```
+@article{song2023deepspeed4science,
+ title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
+ author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
+ journal={arXiv preprint arXiv:2310.04610},
+ year={2023}
+}
+```
diff --git a/blogs/deepspeed4science/chinese/README.md b/blogs/deepspeed4science/chinese/README.md
index 07647c767553..dabc4ab077f2 100644
--- a/blogs/deepspeed4science/chinese/README.md
+++ b/blogs/deepspeed4science/chinese/README.md
@@ -12,6 +12,17 @@
*图1:DeepSpeed4Science方法概述:专为加速科学发现和应对其复杂性而量身定制的AI系统技术开发。*
+如需引用 DeepSpeed4Science,请引用我们的[white paper](https://arxiv.org/abs/2310.04610):
+
+```
+@article{song2023deepspeed4science,
+ title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
+ author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
+ journal={arXiv preprint arXiv:2310.04610},
+ year={2023}
+}
+```
+
## 简介
在接下来的十年中,深度学习可能会彻底改变自然科学,增强我们对自然现象进行建模和预测的能力。这可能预示着科学探索的新时代,为从药物开发到可再生能源的各个领域带来重大进展。为了响应这一机会以及微软“予力全球每一人、每一组织,成就不凡”的使命,[微软DeepSpeed团队](https://www.deepspeed.ai/)启动了一个名为[DeepSpeed4Science](https://deepspeed4science.ai/)的新计划,旨在通过AI系统技术创新帮助领域专家解锁当今最大的科学之谜。
diff --git a/blogs/deepspeed4science/japanese/README.md b/blogs/deepspeed4science/japanese/README.md
index 774ef79a17dc..276528650ab5 100644
--- a/blogs/deepspeed4science/japanese/README.md
+++ b/blogs/deepspeed4science/japanese/README.md
@@ -12,6 +12,17 @@
*図1:DeepSpeed4Scienceのアプローチ: 汎用の言語モデルのサポートを超え、科学的発見とその複雑さの解決に特化したAI技術を開発*
+DeepSpeed4Science を引用するには、こちらの[white paper](https://arxiv.org/abs/2310.04610)を引用してください:
+
+```
+@article{song2023deepspeed4science,
+ title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
+ author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
+ journal={arXiv preprint arXiv:2310.04610},
+ year={2023}
+}
+```
+
## はじめに
自然の出来事をモデル化し予測する深層学習の能力は急速に高まっており、次の10年間に、自然科学に革命を起こすかも知れません。薬の開発から再生可能エネルギーまでの各セクターで、大きな進展をもたらす新しい科学的探求の時代が到来するでしょう。「地球上のすべての人と組織がもっと多くのことを成し遂げられるようにする」というMicrosoftのミッションに従い、この機会に、[DeepSpeedチーム](https://www.deepspeed.ai/)では[DeepSpeed4Science](https://deepspeed4science.ai/)という新しいイニシアティブを立ち上げました。これは、AIシステム技術のイノベーションを通じて他に類を見ない技術を構築し、様々な分野の専門家が、科学分野における大きな謎を解き明かす手助けをすることを目指しています。
diff --git a/csrc/cpu/comm/ccl.cpp b/csrc/cpu/comm/ccl.cpp
index accf431f6929..9a04b6f873f2 100644
--- a/csrc/cpu/comm/ccl.cpp
+++ b/csrc/cpu/comm/ccl.cpp
@@ -277,10 +277,20 @@ int world_size = -1;
std::set _comm_ids;
std::set _colors;
-ccl::vector_class _ccl_comms;
+std::vector _ccl_comms;
+ccl::shared_ptr_class sub_kvs;
+std::map, int> group_to_comm_id;
ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; }
ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; }
+ccl::communicator& _get_comm_from_group(std::vector ranks)
+{
+ if (group_to_comm_id.find(ranks) != group_to_comm_id.end()) {
+ auto id = group_to_comm_id.find(ranks);
+ return _ccl_comms[id->second];
+ }
+ return _ccl_comms[0];
+}
#define CCLCHECK(cmd) \
do { \
@@ -394,12 +404,29 @@ int next_unique_val(std::set s)
}
}
-py::object new_group(std::vector ranks)
+std::vector get_sub_kvs_addr(bool first)
+{
+ if (first) {
+ sub_kvs = ccl::create_main_kvs();
+ ccl::kvs::address_type main_addr = sub_kvs->get_address();
+ auto ccl_kvs_addr = std::vector(main_addr.begin(), main_addr.end());
+ return ccl_kvs_addr;
+ } else {
+ ccl::kvs::address_type main_addr;
+ auto ccl_kvs_addr = std::vector(main_addr.begin(), main_addr.end());
+ return ccl_kvs_addr;
+ }
+}
+
+void initialize_sub_comm(int size, int rank, torch::Tensor& kvs_data, std::vector ranks)
{
- int comm_id = next_unique_val(_comm_ids);
- int color = next_unique_val(_colors);
- std::cout << "RANK: " << get_rank() << " COMM_ID: " << comm_id << " COLOR: " << color
- << std::endl;
+ ccl::kvs::address_type main_addr;
+ if (rank != 0) {
+ memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size());
+ sub_kvs = ccl::create_kvs(main_addr);
+ }
+ _ccl_comms.push_back(ccl::create_communicator(size, rank, sub_kvs));
+ group_to_comm_id[ranks] = _ccl_comms.size() - 1;
}
ccl::datatype get_ccl_datatype(c10::ScalarType type)
@@ -452,7 +479,7 @@ ccl::reduction get_ccl_reduce_op(py::object op, at::Tensor& input)
return ccl_op;
}
-void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
+void broadcast(torch::Tensor& data, int src, std::vector group, bool async_op)
{
CCLCHECK(ccl::broadcast(data.data_ptr(),
data.numel(),
@@ -463,7 +490,7 @@ void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
}
// TODO: implement torch's async_op behavior, document it.
-void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
+void all_reduce(torch::Tensor& data, py::object op, std::vector group, bool async_op)
{
CCLCHECK(ccl::allreduce(data.data_ptr(),
data.data_ptr(),
@@ -477,7 +504,7 @@ void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async
void all_reduce_caching(torch::Tensor& data,
py::object op,
std::string match_id,
- py::object group,
+ std::vector group,
bool async_op)
{
ccl::allreduce_attr attr = ccl::default_allreduce_attr;
@@ -510,7 +537,7 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes)
}
}
-void inference_all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
+void inference_all_reduce(torch::Tensor& data, py::object op, std::vector group, bool async_op)
{
static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
@@ -583,11 +610,18 @@ void inference_all_reduce(torch::Tensor& data, py::object op, py::object group,
}
}
-void barrier(py::object group, bool async_op)
+void barrier(std::vector group, bool async_op)
{
CCLCHECK(ccl::barrier(_get_comm_from_group(group)).wait());
}
+std::vector get_available_coll()
+{
+ std::vector colls{
+ "broadcast", "all_reduce", "inference_all_reduce", "all_reduce_caching", "barrier"};
+ return colls;
+}
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("get_kvs_addr", &get_kvs_addr, "create and get main kvs addr");
@@ -599,4 +633,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
m.def("all_reduce_caching", &all_reduce_caching, "ccl all_reduce with caching");
m.def("barrier", &barrier, "barrier");
+ m.def("initialize_sub_comm", &initialize_sub_comm, "initialize_sub_comm");
+ m.def("get_sub_kvs_addr", &get_sub_kvs_addr, "get_sub_kvs_addr");
+ m.def("get_available_coll", &get_available_coll, "get_available_coll");
}
diff --git a/csrc/includes/quantization.h b/csrc/includes/quantization.h
index d2873abf1839..45828832d8d2 100644
--- a/csrc/includes/quantization.h
+++ b/csrc/includes/quantization.h
@@ -98,3 +98,11 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
int num_group,
int group_size,
cudaStream_t stream);
+
+void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
+ half* data_out,
+ half* scale_buffer,
+ half* min_val_buffer,
+ int num_group,
+ int group_size,
+ cudaStream_t stream);
diff --git a/csrc/quantization/pt_binding.cpp b/csrc/quantization/pt_binding.cpp
index d4c253ee005d..a4210897092d 100644
--- a/csrc/quantization/pt_binding.cpp
+++ b/csrc/quantization/pt_binding.cpp
@@ -156,6 +156,26 @@ at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in,
return output;
}
+at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
+ at::Tensor& scale_buffer,
+ at::Tensor& min_val_buffer,
+ int num_group,
+ int group_size)
+{
+ auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+ auto output = torch::empty({num_group, group_size}, output_options);
+
+ launch_dequantize_int8_to_half_experimental((uint8_t*)data_in.data_ptr(),
+ (half*)output.data_ptr(),
+ (half*)scale_buffer.data_ptr(),
+ (half*)min_val_buffer.data_ptr(),
+ num_group,
+ group_size,
+ at::cuda::getCurrentCUDAStream());
+
+ return output;
+}
+
std::vector ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
@@ -270,6 +290,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("dequantize_int4_to_half_experimental",
&dequantize_int4_to_half_experimental,
"Dequantize int4 to half (experimental)");
+ m.def("dequantize_int8_to_half_experimental",
+ &dequantize_int8_to_half_experimental,
+ "Dequantize int8 to half (experimental)");
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
}
diff --git a/csrc/quantization/quantize_int4.cu b/csrc/quantization/quantize_intX.cu
similarity index 76%
rename from csrc/quantization/quantize_int4.cu
rename to csrc/quantization/quantize_intX.cu
index fed707c1fa7c..b26151ab5c8c 100644
--- a/csrc/quantization/quantize_int4.cu
+++ b/csrc/quantization/quantize_intX.cu
@@ -228,3 +228,54 @@ void launch_dequantize_int4_to_half_experimental(uint8_t* data_in,
dequantize_int4_to_half<<>>(
data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
}
+
+template
+__device__ __forceinline__ AlignedArray int8_to_half(const AlignedArray& data)
+{
+ AlignedArray ret;
+
+#pragma unroll
+ for (int idx = 0; idx < N; idx += 1) { ret[idx] = half(int(data[idx])); }
+
+ return ret;
+}
+
+__global__ void dequantize_int8_to_half(uint8_t* data_in,
+ half* data_out,
+ half* scale_buffer,
+ half* min_val_buffer,
+ int num_group,
+ int group_size)
+{
+ using AccessType = AlignedArray;
+ using AccessTypeOut = AlignedArray;
+
+ for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < num_group * group_size / 8;
+ idx += blockDim.x * gridDim.x) {
+ int id_group = idx / (group_size / 8);
+ AccessType value = reinterpret_cast(data_in)[idx];
+ half scale = scale_buffer[id_group];
+ half min_value = min_val_buffer[id_group];
+
+ AccessTypeOut output = int8_to_half(value);
+ output = divide()(output, scale);
+ output = plus()(output, min_value);
+
+ reinterpret_cast(data_out)[idx] = output;
+ }
+}
+
+void launch_dequantize_int8_to_half_experimental(uint8_t* data_in,
+ half* data_out,
+ half* scale_buffer,
+ half* min_val_buffer,
+ int num_group,
+ int group_size,
+ cudaStream_t stream)
+{
+ int num_warp = num_group / 4;
+ int num_block = num_warp / 8; // 256 trd / block
+
+ dequantize_int8_to_half<<>>(
+ data_in, data_out, scale_buffer, min_val_buffer, num_group, group_size);
+}
diff --git a/deepspeed/comm/ccl.py b/deepspeed/comm/ccl.py
index 65646f6ba646..859e70c7db3b 100644
--- a/deepspeed/comm/ccl.py
+++ b/deepspeed/comm/ccl.py
@@ -21,6 +21,16 @@ def build_ccl_op():
return ccl_cpp_module
+class CCLHandler():
+
+ def __init__(self, ccl_comm_op=None):
+ self.ccl_comm_op = ccl_comm_op
+
+ def wait(self):
+ # backend covered it
+ pass
+
+
class CCLBackend(TorchBackend):
def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, init_method=None):
@@ -39,27 +49,129 @@ def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, i
size = self.get_world_size()
rank = self.get_rank()
main_kvs = self.ccl_comm_op.get_kvs_addr(rank)
- main_kvs = torch.tensor(main_kvs).to(torch.uint8)
+ main_kvs = torch.tensor(main_kvs).to(torch.uint8).to(get_accelerator().current_device_name())
super(CCLBackend, self).broadcast(main_kvs, 0)
self.ccl_comm_op.initialize(size, rank, main_kvs)
self.initialized = True
+ self.groups = [tuple(range(self.get_world_size()))]
+ self.available_coll = self.ccl_comm_op.get_available_coll()
def is_initialized(self):
return self.initialized
- def broadcast(self, tensor, src, group=None, async_op=False):
- self.ccl_comm_op.broadcast(tensor, src, group, async_op)
+ def run_collective(self, name, **kwargs):
+ if name in self.available_coll:
+ kwargs['group'] = self.get_all_ranks_from_group(kwargs['group'])
+ if 'dst' in kwargs:
+ kwargs['dst'] = kwargs['group'].index(kwargs['dst'])
+ if 'src' in kwargs:
+ kwargs['src'] = kwargs['group'].index(kwargs['src'])
+ func = "self.ccl_comm_op." + name
+ eval(func)(*(kwargs.values()))
+ return CCLHandler(self.ccl_comm_op)
+ else:
+ func = "super(CCLBackend, self)." + name
+ return eval(func)(*(kwargs.values()))
def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
use_caching = False
if use_caching:
match_id = f"{tensor.size()}-{op}"
- self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
+ return self.run_collective(name="all_reduce_caching",
+ tensor=tensor,
+ op=op,
+ match_id=match_id,
+ group=group,
+ async_op=async_op)
else:
- self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
+ return self.run_collective(name="all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)
def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
- self.ccl_comm_op.inference_all_reduce(tensor, op, group, async_op)
+ return self.run_collective(name="inference_all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)
+
+ def broadcast(self, tensor, src, group=None, async_op=False):
+ return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)
+
+ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
+ return self.run_collective(name="all_gather",
+ tensor_list=tensor_list,
+ tensor=tensor,
+ group=group,
+ async_op=async_op)
+
+ def reduce_scatter_tensor(self, output_tensor, input_tensor, op, group=None, async_op=False):
+ return self.run_collective(name="reduce_scatter_tensor",
+ output_tensor=output_tensor,
+ input_tensor=input_tensor,
+ op=op,
+ group=group)
+
+ def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
+ return self.run_collective(name="all_gather_into_tensor",
+ output_tensor=output_tensor,
+ input_tensor=input_tensor,
+ group=group)
+
+ def all_to_all_single(self, output, input, output_split_sizes, input_split_sizes, group=None, async_op=False):
+ return self.run_collective(name="all_to_all_single",
+ output=output,
+ input=input,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group)
+
+ def send(self, tensor, dst, group=None, async_op=False):
+ return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, async_op=async_op)
+
+ def recv(self, tensor, src, group=None, async_op=False):
+ return self.run_collective(name="recv", tensor=tensor, src=src, group=group, async_op=async_op)
+
+ def gather(self, tensor, gather_list, dst, group=None, async_op=False):
+ return self.run_collective(name="gather", tensor=tensor, gather_list=gather_list, dst=dst, group=group)
+
+ def scatter(self, tensor, gather_list, dst, group=None, async_op=False):
+ return self.run_collective(name="scatter", tensor=tensor, gather_list=gather_list, dst=dst, group=group)
def barrier(self, group=None, async_op=False):
- self.ccl_comm_op.barrier(group, async_op)
+ return self.run_collective(name="barrier", group=group, async_op=async_op)
+
+ def monitored_barrier(self, group=None, timeout=None, wait_all_ranks=False):
+ return self.run_collective(name="monitored_barrier", group=group)
+
+ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
+ return self.run_collective(name="reduce_scatter",
+ output=output,
+ input_list=input_list,
+ op=op,
+ group=group,
+ async_op=async_op)
+
+ def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+ return self.run_collective(name="reduce", tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)
+
+ def new_group(self, ranks):
+ return super(CCLBackend, self).new_group(ranks)
+
+ def _new_group(self, ranks, group):
+ size = len(ranks)
+ rank = self.get_rank()
+ sub_main_kvs = self.ccl_comm_op.get_sub_kvs_addr(rank == ranks[0])
+ sub_main_kvs = torch.tensor(sub_main_kvs).to(torch.uint8).to(get_accelerator().current_device_name())
+ super(CCLBackend, self).broadcast(sub_main_kvs, ranks[0], group)
+ self.ccl_comm_op.initialize_sub_comm(size, ranks.index(rank), sub_main_kvs, ranks)
+ self.groups.append(tuple(ranks))
+
+ def get_all_ranks_from_group(self, group):
+ if group is None:
+ return list(range(self.get_world_size()))
+ rank = 0
+ results = []
+ try:
+ while True:
+ results.append(super(CCLBackend, self).get_global_rank(group, rank))
+ rank += 1
+ except RuntimeError:
+ pass
+ if tuple(results) not in self.groups:
+ self._new_group(results, group)
+ return results
diff --git a/deepspeed/comm/config.py b/deepspeed/comm/config.py
index 138badebe5a9..1c441bb6bfe9 100644
--- a/deepspeed/comm/config.py
+++ b/deepspeed/comm/config.py
@@ -3,8 +3,8 @@
# DeepSpeed Team
-from pydantic import BaseModel
from .constants import *
+from ..pydantic_v1 import BaseModel
class CommsConfig(BaseModel):
diff --git a/deepspeed/constants.py b/deepspeed/constants.py
index 7ebc8f9983a5..30135f41b7b6 100644
--- a/deepspeed/constants.py
+++ b/deepspeed/constants.py
@@ -3,6 +3,7 @@
# DeepSpeed Team
+import os
from datetime import timedelta
#############################################
@@ -15,6 +16,6 @@
# (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1).
# To make an attempt at backwards compatibility with THD, we use an
# extraordinarily high default timeout, given that THD did not have timeouts.
-default_pg_timeout = timedelta(minutes=30)
+default_pg_timeout = timedelta(minutes=int(os.getenv("DEEPSPEED_TIMEOUT", default=30)))
INFERENCE_GENERIC_MODE = 'generic'
INFERENCE_SPECIALIZED_MODE = 'specialized'
diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index 70f1c0dbd5b7..1d5018aaa75b 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -5,10 +5,9 @@
import torch
import deepspeed
+from deepspeed.pydantic_v1 import Field, validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
-from pydantic import Field
-from pydantic import validator
from typing import Dict, Union
from enum import Enum
diff --git a/deepspeed/inference/quantization/utils.py b/deepspeed/inference/quantization/utils.py
index d47eb265c214..712abc384a44 100644
--- a/deepspeed/inference/quantization/utils.py
+++ b/deepspeed/inference/quantization/utils.py
@@ -105,20 +105,24 @@ def __init__(self, config: Dict, dtype: torch.dtype) -> None:
def dequantize(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) -> Tensor:
# Use customized CUDA quantization kernel if possible.
if self.config['group_size'] % 8 == 0 and \
- self.config['num_bits'] == 4 and \
+ (self.config['num_bits'] == 4 or self.config['num_bits'] == 8) and \
self.config['group_dim'] == len(tensor.shape) - 1 and \
self.dtype == torch.float16 and device == 'cuda':
last_dimension_size = self.config['group_size']
if self.config['num_bits'] == 4:
last_dimension_size = last_dimension_size // 2
- quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
- tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
- tensor.numel() // last_dimension_size, self.config['group_size'])
-
- shape = list(tensor.shape)
- if self.config['num_bits'] == 4:
+ quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
+ tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
+ tensor.numel() // last_dimension_size, self.config['group_size'])
+ shape = list(tensor.shape)
shape[-1] = shape[-1] * 2
+ elif self.config['num_bits'] == 8:
+ # last_dimension_size = last_dimension_size // 2
+ quantized_tensor = get_quantizer_cuda_module().dequantize_int8_to_half_experimental(
+ tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
+ tensor.numel() // last_dimension_size, self.config['group_size'])
+ shape = list(tensor.shape)
return quantized_tensor.reshape(shape)
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index daf143919558..2e348de63454 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -11,7 +11,7 @@
from typing import Optional
import torch
from deepspeed import comm as dist
-from .layers import LinearAllreduce, LinearLayer
+from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
@@ -318,6 +318,11 @@ def _replace(self, child, name, conv_linear_layer):
del data
setattr(child, "replaced", True)
+ if name == "lm_head" or name == 'embed_out':
+ return LmHeadLinearAllreduce(
+ torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
+ child.bias if child.bias is None else torch.nn.parameter.Parameter(
+ child.bias.to(get_accelerator().current_device_name())), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group)
else:
@@ -436,3 +441,16 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
self.update_mp_params(child)
self._replace_module(child, name, class_name)
return r_module
+
+ def _replace_last_linear_module(self, r_module):
+ if hasattr(r_module, "lm_head"):
+ name = "lm_head"
+ child = r_module.lm_head
+ elif hasattr(r_module, "embed_out"):
+ name = "embed_out"
+ child = r_module.embed_out
+ else:
+ return r_module
+ if child.__class__ in self.linear_policies:
+ setattr(r_module, name, self.linear_policies[child.__class__](child, name, self.conv_linear_layer))
+ return r_module
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index aa29651ec4cf..7a565560dec9 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -29,6 +29,36 @@ def forward(self, input):
return output
+class LmHeadLinearAllreduce(nn.Module):
+
+ def __init__(
+ self,
+ weight,
+ rank,
+ world_size,
+ bias=None,
+ mp_group=None,
+ ):
+ super(LmHeadLinearAllreduce, self).__init__()
+ self.weight = weight
+ self.bias = bias
+ self.mp_group = mp_group
+ self.rank = rank
+ self.world_size = world_size
+
+ def forward(self, input):
+ assert input.shape[
+ -1] % self.world_size == 0, 'Please ensure that self.world_size is divisible by input.shape[-1]'
+ input_shard = input.shape[-1] // self.world_size
+ output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard],
+ self.weight.transpose(-1, -2))
+ if self.mp_group is not None:
+ dist.inference_all_reduce(output, group=self.mp_group)
+ if self.bias is not None:
+ output += self.bias
+ return output
+
+
class LinearLayer(nn.Module):
def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index f9dd921b9ae1..8666372fa3f4 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -275,6 +275,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
_autotp.update_linear_policies()
# 4. Replace modules
+ if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
+ return _autotp._replace_last_linear_module(module)
return _autotp._replace_module(module)
def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
@@ -304,6 +306,13 @@ def set_lm_head(module):
if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
module.lm_head, "weight") and module.lm_head.weight.is_meta:
module.lm_head.weight = embedding_weight
+ # enable tensor parallel for the last linear
+ if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and not module.lm_head.weight.is_meta:
+ module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
+ elif hasattr(module, "embed_out") and hasattr(module.embed_out,
+ "weight") and not module.embed_out.weight.is_meta:
+ module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
+ return module
if checkpoint_dict is not None and not config.replace_with_kernel_inject:
# AutoTP shard loading
@@ -318,7 +327,7 @@ def set_lm_head(module):
checkpoint=checkpoint_file)
pbar.update(1)
gc.collect()
- set_lm_head(replaced_module)
+ replaced_module = set_lm_head(replaced_module)
else:
replaced_module = replace_module(model=model,
orig_class=orig_layer_impl,
diff --git a/deepspeed/monitor/config.py b/deepspeed/monitor/config.py
index 2706764290fd..5a8ca6ecf5cd 100644
--- a/deepspeed/monitor/config.py
+++ b/deepspeed/monitor/config.py
@@ -3,7 +3,7 @@
# DeepSpeed Team
-from pydantic import root_validator
+from deepspeed.pydantic_v1 import root_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py
index fe19299fcb60..ddcabf0d29e5 100644
--- a/deepspeed/profiling/flops_profiler/profiler.py
+++ b/deepspeed/profiling/flops_profiler/profiler.py
@@ -722,9 +722,9 @@ def _upsample_flops_compute(*args, **kwargs):
flops = input.numel()
if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
- flops * int(_prod(scale_factor))
+ flops *= int(_prod(scale_factor))
else:
- flops * scale_factor**len(input)
+ flops *= scale_factor**len(input)
return flops, 0
diff --git a/deepspeed/pydantic_v1.py b/deepspeed/pydantic_v1.py
new file mode 100644
index 000000000000..6aba072ad929
--- /dev/null
+++ b/deepspeed/pydantic_v1.py
@@ -0,0 +1,16 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""Pydantic v1 compatibility module.
+
+Pydantic v2 introduced breaking changes that hinder its adoption:
+https://docs.pydantic.dev/latest/migration/. To provide deepspeed users the option to
+migrate to pydantic v2 on their own timeline, deepspeed uses this compatibility module
+as a pydantic-version-agnostic alias for pydantic's v1 API.
+"""
+
+try:
+ from pydantic.v1 import * # noqa: F401
+except ImportError:
+ from pydantic import * # noqa: F401
diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py
index 77407a52026a..108cb37b57fb 100644
--- a/deepspeed/runtime/activation_checkpointing/checkpointing.py
+++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -270,6 +270,8 @@ def gather_partitioned_activations(tensors, device=None):
# don't need to do all_gather if model parallel is not enabled
if mp_group is None or mp_size == 1:
item = item.view(list(size.numpy()))
+ if device is not None:
+ item = item.to(device)
inputs.append(item)
continue
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index 9c7a84d4841e..550af8fac057 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -365,7 +365,8 @@ def load_state_dict(self,
state_dict_list,
checkpoint_folder,
load_optimizer_states=True,
- load_from_fp32_weights=False):
+ load_from_fp32_weights=False,
+ load_serial=None):
if checkpoint_folder:
self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
else:
diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py
index 0fb1372deac8..5522a8e79d69 100755
--- a/deepspeed/runtime/config_utils.py
+++ b/deepspeed/runtime/config_utils.py
@@ -9,7 +9,7 @@
import collections
import collections.abc
from functools import reduce
-from pydantic import BaseModel
+from deepspeed.pydantic_v1 import BaseModel
from deepspeed.utils import logger
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index b4c8ef56c701..8a8193ddd8f5 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3287,7 +3287,7 @@ def _get_zero_param_shapes(self):
# if we don't use it, we get parameters ordered incorrectly
if hasattr(self.optimizer, "round_robin_bit16_groups"):
bit16_groups = self.optimizer.round_robin_bit16_groups
- elif self.bfloat16_enabled() and not self.zero_optimization():
+ elif self.bfloat16_enabled() and hasattr(self.optimizer, "bf16_groups"):
bit16_groups = self.optimizer.bf16_groups
else:
bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage(
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 4e8a0faa0ae0..0c7c9f7a1090 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -386,7 +386,7 @@ def train_batch(self, data_iter=None):
# TODO: should return precisely what loss returned and allow others to be queried?
return self.agg_train_loss
- def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_output='avg'):
+ def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_output='avg', bcast_loss=True):
"""Evaluate the pipeline on a batch of data from ``data_iter``. The
engine will evaluate ``self.train_batch_size()`` total samples
collectively across all workers.
@@ -449,7 +449,7 @@ def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_o
if self.is_last_stage():
eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output)
- if compute_loss:
+ if compute_loss and (bcast_loss or self.monitor.enabled):
eval_output = self._bcast_pipe_scalar(eval_output)
if self.global_rank == 0 and self.monitor.enabled:
diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py
index 1fc11f0e46f5..35d60b5b3290 100644
--- a/deepspeed/runtime/zero/config.py
+++ b/deepspeed/runtime/zero/config.py
@@ -3,10 +3,10 @@
# DeepSpeed Team
-from pydantic import Field, validator
import sys
from typing import Optional
from enum import Enum
+from deepspeed.pydantic_v1 import Field, validator
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
from deepspeed.utils import logger
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py
index c3a6dc7af530..1bd79412d39f 100644
--- a/deepspeed/runtime/zero/offload_config.py
+++ b/deepspeed/runtime/zero/offload_config.py
@@ -3,9 +3,9 @@
# DeepSpeed Team
-from pydantic import Field, validator
from enum import Enum
from pathlib import Path
+from deepspeed.pydantic_v1 import Field, validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int
diff --git a/docs/_pages/deepspeed4science.md b/docs/_pages/deepspeed4science.md
index 6dd87ce996e2..b35351838f22 100755
--- a/docs/_pages/deepspeed4science.md
+++ b/docs/_pages/deepspeed4science.md
@@ -6,7 +6,18 @@ toc_label: "Contents"
toc_sticky: true
---
-In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts to unlock today's biggest science mysteries. This page serves as an overview page for all technologies released (or to be released in the future) as part of the DeepSpeed4Science initiative, making it easier for scientists to shop for techniques they need. Details of the DeepSpeed4Science initiative can be found at [our website](https://deepspeed4science.ai/). For each technique we will introduce what is it for, when to use it, links to how to use it, and existing scientific applications of the techniques (we welcome users to contribute more showcases if you apply our techniques in your scientific research):
+In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts to unlock today's biggest science mysteries. This page serves as an overview page for all technologies released (or to be released in the future) as part of the DeepSpeed4Science initiative, making it easier for scientists to shop for techniques they need. Details of the DeepSpeed4Science initiative can be found at [our website](https://deepspeed4science.ai/). For each technique we will introduce what is it for, when to use it, links to how to use it, and existing scientific applications of the techniques (we welcome users to contribute more showcases if you apply our techniques in your scientific research).
+
+To cite DeepSpeed4Science, please cite our [white paper](https://arxiv.org/abs/2310.04610):
+
+```
+@article{song2023deepspeed4science,
+ title={DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies},
+ author={Song, Shuaiwen Leon and Kruft, Bonnie and Zhang, Minjia and Li, Conglong and Chen, Shiyang and Zhang, Chengming and Tanaka, Masahiro and Wu, Xiaoxia and Rasley, Jeff and Awan, Ammar Ahmad and others},
+ journal={arXiv preprint arXiv:2310.04610},
+ year={2023}
+}
+```
* [2023/09] We are releasing two techniques: [DeepSpeed4Science large-scale training framework](#new-megatron-deepspeed-for-large-scale-ai4science-model-training), [DS4Sci_EvoformerAttention](#memory-efficient-evoformerattention-kernels) and their scientific applications in structural biology research.
diff --git a/docs/_tutorials/advanced-install.md b/docs/_tutorials/advanced-install.md
index c2b4c04cad1c..10197e62f681 100755
--- a/docs/_tutorials/advanced-install.md
+++ b/docs/_tutorials/advanced-install.md
@@ -61,6 +61,7 @@ Available `DS_BUILD` options include:
* `DS_BUILD_CCL_COMM` builds the communication collective libs
* `DS_BUILD_CPU_ADAM` builds the CPUAdam op
* `DS_BUILD_CPU_LION` builds the CPULion op
+* `DS_BUILD_EVOFORMER_ATTN` builds the EvoformerAttn op (from [Alphafold](https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/))
* `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex))
* `DS_BUILD_FUSED_LION` builds the FusedLion op
* `DS_BUILD_CPU_ADAGRAD` builds the CPUAdagrad op
@@ -71,7 +72,6 @@ Available `DS_BUILD` options include:
* `DS_BUILD_TRANSFORMER` builds the transformer op
* `DS_BUILD_TRANSFORMER_INFERENCE` builds the transformer-inference op
* `DS_BUILD_STOCHASTIC_TRANSFORMER` builds the stochastic transformer op
-* `DS_BUILD_UTILS` builds various optimized utilities
To speed up the build-all process, you can parallelize the compilation process with:
diff --git a/docs/index.md b/docs/index.md
index 14c131a9a22d..f6b3eb18ed1f 100755
--- a/docs/index.md
+++ b/docs/index.md
@@ -8,7 +8,7 @@ title: "Latest News"
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat).
* [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)]
-* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](/deepspeed4science/)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
+* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses) [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/japanese/README.md)]
@@ -137,6 +137,7 @@ comments.
25. Zhewei Yao, Reza Yazdani Aminabadi, Olatunji Ruwase, Samyam Rajbhandari, Xiaoxia Wu, Ammar Ahmad Awan, Jeff Rasley, Minjia Zhang, Conglong Li, Connor Holmes, Zhongzhu Zhou, Michael Wyatt, Molly Smith, Lev Kurilenko, Heyang Qin, Masahiro Tanaka, Shuai Che, Shuaiwen Leon Song, Yuxiong He. (2023) DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales [arXiv:2308.01320](https://arxiv.org/abs/2308.01320).
26. Xiaoxia Wu, Zhewei Yao, Yuxiong He. (2023) ZeroQuant-FP: A Leap Forward in LLMs Post-Training W4A8 Quantization Using Floating-Point Formats [arXiv:2307.09782](https://arxiv.org/abs/2307.09782)
27. Zhewei Yao, Xiaoxia Wu, Conglong Li, Minjia Zhang, Heyang Qin, Olatunji Ruwase, Ammar Ahmad Awan, Samyam Rajbhandari, Yuxiong He. (2023) DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention [arXiv:2309.14327](https://arxiv.org/pdf/2309.14327.pdf)
+28. Shuaiwen Leon Song, Bonnie Kruft, Minjia Zhang, Conglong Li, Shiyang Chen, Chengming Zhang, Masahiro Tanaka, Xiaoxia Wu, Jeff Rasley, Ammar Ahmad Awan, Connor Holmes, Martin Cai, Adam Ghanem, Zhongzhu Zhou, Yuxiong He, et al. (2023) DeepSpeed4Science Initiative: Enabling Large-Scale Scientific Discovery through Sophisticated AI System Technologies [arXiv:2310.04610](https://arxiv.org/abs/2310.04610)
# Videos
1. DeepSpeed KDD 2020 Tutorial
diff --git a/environment.yml b/environment.yml
index e55fe96e5a5a..28c298717d80 100644
--- a/environment.yml
+++ b/environment.yml
@@ -18,4 +18,4 @@ dependencies:
- certifi
- openssl
- python=3.10
- - pydantic<2.0.0
+ - pydantic
diff --git a/op_builder/async_io.py b/op_builder/async_io.py
index da511a0a8c9d..b55c821910b9 100644
--- a/op_builder/async_io.py
+++ b/op_builder/async_io.py
@@ -5,7 +5,6 @@
import distutils.spawn
import subprocess
-import torch
from .builder import OpBuilder
@@ -36,6 +35,7 @@ def cxx_args(self):
# -O0 for improved debugging, since performance is bound by I/O
CPU_ARCH = self.cpu_arch()
SIMD_WIDTH = self.simd_width()
+ import torch # Keep this import here to avoid errors when building DeepSpeed wheel without torch installed
TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[0:2])
if TORCH_MAJOR >= 2 and TORCH_MINOR >= 1:
CPP_STD = '-std=c++17'
diff --git a/op_builder/quantizer.py b/op_builder/quantizer.py
index ada80b8f3331..fd765b743de0 100644
--- a/op_builder/quantizer.py
+++ b/op_builder/quantizer.py
@@ -22,7 +22,7 @@ def sources(self):
'csrc/quantization/pt_binding.cpp',
'csrc/quantization/fake_quantizer.cu',
'csrc/quantization/quantize.cu',
- 'csrc/quantization/quantize_int4.cu',
+ 'csrc/quantization/quantize_intX.cu',
'csrc/quantization/dequantize.cu',
'csrc/quantization/swizzled_quantize.cu',
'csrc/quantization/quant_reduce.cu',
diff --git a/requirements/requirements-readthedocs.txt b/requirements/requirements-readthedocs.txt
index a6d7915e0ea5..fcd0ec5a9a6a 100644
--- a/requirements/requirements-readthedocs.txt
+++ b/requirements/requirements-readthedocs.txt
@@ -1,9 +1,9 @@
-autodoc_pydantic<2.0.0
+autodoc_pydantic
docutils<0.18
hjson
packaging
psutil
py-cpuinfo
-pydantic<2.0.0
+pydantic
torch
tqdm
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 8c5e76750573..6840d6dbcc98 100755
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -4,6 +4,6 @@ numpy
packaging>=20.0
psutil
py-cpuinfo
-pydantic<2.0.0
+pydantic
torch
tqdm
diff --git a/tests/unit/inference/quantization/test_int4_quantization.py b/tests/unit/inference/quantization/test_intX_quantization.py
similarity index 91%
rename from tests/unit/inference/quantization/test_int4_quantization.py
rename to tests/unit/inference/quantization/test_intX_quantization.py
index 56a5a7d48382..56df2b232d15 100644
--- a/tests/unit/inference/quantization/test_int4_quantization.py
+++ b/tests/unit/inference/quantization/test_intX_quantization.py
@@ -53,12 +53,11 @@ def quantization_test_helper(pre_quant_type: torch.dtype, num_bits: int):
assert mean_diff < 0.15 and max_diff < 0.5, f'Numeric error exceed threshold, mean diff {mean_diff} (threshold 0.15), max diff {max_diff} (threshold 0.5)'
-def zero3_post_init_quantization_test_helper(cpu_offload: bool, nvme_offload: bool):
+def zero3_post_init_quantization_test_helper(cpu_offload: bool, nvme_offload: bool, bits: int):
import deepspeed
from transformers.deepspeed import HfDeepSpeedConfig
- def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool) -> Dict:
- bits = 4
+ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict:
GB = 1 << 30
ds_config = {
@@ -143,7 +142,7 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
return ds_config
hf_config = AutoConfig.from_pretrained('facebook/opt-125m')
- ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload)
+ ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload, bits=bits)
input_ids = torch.ones(1, 16, dtype=torch.int32, device=device)
attention_mask = torch.ones(1, 16, dtype=torch.float32, device=device)
@@ -171,12 +170,11 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
assert mean_diff < 0.4, f'Numeric error exceed threshold, relative error {mean_diff} (threshold 0.4)'
-def zero3_quantized_initialization_test_helper(cpu_offload: bool, nvme_offload: bool):
+def zero3_quantized_initialization_test_helper(cpu_offload: bool, nvme_offload: bool, bits: int):
import deepspeed
from transformers.deepspeed import HfDeepSpeedConfig
- def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool) -> Dict:
- bits = 4
+ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict:
GB = 1 << 30
ds_config = {
@@ -223,7 +221,7 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
return ds_config
hf_config = AutoConfig.from_pretrained('facebook/opt-125m')
- ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload)
+ ds_config = get_zero3_ds_config(hf_config=hf_config, cpu_offload=cpu_offload, nvme_offload=nvme_offload, bits=bits)
input_ids = torch.ones(1, 16, dtype=torch.int32, device=device)
attention_mask = torch.ones(1, 16, dtype=torch.float32, device=device)
@@ -249,16 +247,26 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b
assert mean_diff < 0.4, f'Numeric error exceed threshold, relative error {mean_diff} (threshold 0.4)'
-class TestQuantizedInt4(DistributedTest):
+@pytest.fixture(params=[4, 8], ids=["4bits", "8bits"])
+def quantization_bits(request):
+ return request.param
- def test_model_quantization(self):
+
+@pytest.fixture(params=[0, 1], ids=["0", "1"])
+def group_dim(request):
+ return request.param
+
+
+class TestQuantizedInt(DistributedTest):
+
+ def test_model_quantization(self, quantization_bits):
reset_random()
config = AutoConfig.from_pretrained('facebook/opt-125m')
with torch.no_grad():
model = OPTDecoderLayer(config).half().to(device)
- bits = 4
+ bits = quantization_bits
ds_config = {
'weight_quantization': {
@@ -307,7 +315,7 @@ def test_model_quantization(self):
assert type(model.self_attn.out_proj) is QuantizedLinear
@pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
- def test_quantized_linear(self):
+ def test_quantized_linear(self, quantization_bits, group_dim):
reset_random()
layers = []
@@ -326,9 +334,9 @@ def test_quantized_linear(self):
'weight_quantization': {
'post_init_quant': {
'layer': {
- 'num_bits': 4,
+ 'num_bits': quantization_bits,
'group_size': 64,
- 'group_dim': 0,
+ 'group_dim': group_dim,
'symmetric': False
}
}
@@ -368,31 +376,31 @@ def test_half_int8_quantization(self):
quantization_test_helper(torch.float16, 8)
@pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
- def test_zero3_int4_post_init_quant(self):
+ def test_zero3_int4_post_init_quant(self, quantization_bits):
reset_random()
- zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=False)
+ zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=False, bits=quantization_bits)
@pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
- def test_zero3_int4_post_init_quant_cpu_offload(self):
+ def test_zero3_int4_post_init_quant_cpu_offload(self, quantization_bits):
reset_random()
- zero3_post_init_quantization_test_helper(cpu_offload=True, nvme_offload=False)
+ zero3_post_init_quantization_test_helper(cpu_offload=True, nvme_offload=False, bits=quantization_bits)
@pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
def test_zero3_int4_post_init_quant_nvme_offload(self):
reset_random()
- zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=True)
+ zero3_post_init_quantization_test_helper(cpu_offload=False, nvme_offload=True, bits=4)
@pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
- def test_zero3_int4_quantized_initialization(self):
+ def test_zero3_int4_quantized_initialization(self, quantization_bits):
reset_random()
- zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=False)
+ zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=False, bits=quantization_bits)
@pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
- def test_zero3_int4_quantized_initialization_cpu_offload(self):
+ def test_zero3_int4_quantized_initialization_cpu_offload(self, quantization_bits):
reset_random()
- zero3_quantized_initialization_test_helper(cpu_offload=True, nvme_offload=False)
+ zero3_quantized_initialization_test_helper(cpu_offload=True, nvme_offload=False, bits=quantization_bits)
@pytest.mark.skipif(device == 'cpu', reason='CPU does support FP16 GEMM')
def test_zero3_int4_quantized_initialization_nvme_offload(self):
reset_random()
- zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=True)
+ zero3_quantized_initialization_test_helper(cpu_offload=False, nvme_offload=True, bits=4)
diff --git a/tests/unit/runtime/test_ds_config_model.py b/tests/unit/runtime/test_ds_config_model.py
index b9c67c9a30dd..87ea747cf423 100644
--- a/tests/unit/runtime/test_ds_config_model.py
+++ b/tests/unit/runtime/test_ds_config_model.py
@@ -6,8 +6,8 @@
import pytest
import os
import json
-from pydantic import Field, ValidationError
from typing import List
+from deepspeed.pydantic_v1 import Field, ValidationError
from deepspeed.runtime import config as ds_config
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
diff --git a/version.txt b/version.txt
index 9b40aa6c214f..bc859cbd6d99 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.10.4
+0.11.2