diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml
index f81e690e835b..ac19638e67de 100644
--- a/.github/workflows/hpu-gaudi2.yml
+++ b/.github/workflows/hpu-gaudi2.yml
@@ -39,13 +39,14 @@ jobs:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
- image: vault.habana.ai/gaudi-docker/1.15.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest
+ image: vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
env:
PT_HPU_LAZY_MODE: 0
+ TORCHINDUCTOR_COMPILE_THREADS: 1
TEST_LIST: |
test_accelerator.py
test_autotuning.py
@@ -103,7 +104,7 @@ jobs:
- name: Check container state
run: |
ldd --version
- hl-smi
+ hl-smi -L
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
@@ -128,7 +129,7 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
export PT_HPU_LAZY_MODE=${PT_HPU_LAZY_MODE}
+ export TORCHINDUCTOR_COMPILE_THREADS=${TORCHINDUCTOR_COMPILE_THREADS}
TEST_LIST=$(echo "$TEST_LIST" | awk 'NF{printf "%s%s", (NR>1 ? " or " : ""), $0} END{if (NR>1) print ""}')
echo "TEST_LIST ${TEST_LIST}"
- echo "PT_HPU_LAZY_MODE ${PT_HPU_LAZY_MODE}"
pytest --verbose unit/ -k "${TEST_LIST}"
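The awk one-liner above simply folds the newline-separated `TEST_LIST` into a single pytest `-k` filter joined by `or`. A rough Python equivalent, purely for illustration (the two test names come from the list earlier in the workflow):

```python
# Fold a newline-separated test list into a pytest -k expression,
# mirroring the awk step in the workflow above.
test_list = """test_accelerator.py
test_autotuning.py""".splitlines()

k_expr = " or ".join(name for name in test_list if name.strip())
print(k_expr)  # test_accelerator.py or test_autotuning.py
```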
diff --git a/README.md b/README.md
index 5f990fd70d7d..304169b56777 100755
--- a/README.md
+++ b/README.md
@@ -15,6 +15,9 @@
## Latest News
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat).
+
+
+* [2024/08] [DeepNVMe: Improving DL Applications through I/O Optimizations](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-gds/README.md) [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-gds/japanese/README.md)]
* [2024/07] [DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/README.md) [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/japanese/README.md)]
* [2024/03] [DeepSpeed-FP6: The power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)]
* [2024/01] [DeepSpeed-FastGen: Introducing Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)
diff --git a/blogs/deepspeed-gds/japanese/README.md b/blogs/deepspeed-gds/japanese/README.md
new file mode 100644
index 000000000000..8d65d5225b16
--- /dev/null
+++ b/blogs/deepspeed-gds/japanese/README.md
@@ -0,0 +1,77 @@
+
+
+# DeepNVMe: Improving Deep Learning Applications through I/O Optimizations
+
+
+
+# Introduction
+
+Deep learning (DL) continues to drive unprecedented advances in important AI application domains such as language, speech, video, and multimodal applications. A key factor behind this progress is dramatic scalability along multiple dimensions, including model size, sequence length, and hardware parallelism. From a systems perspective, DL scalability puts enormous pressure on critical subsystems such as compute, memory, communication, and storage. However, existing efforts have largely neglected the storage subsystem, so I/O operations such as data loading, model checkpointing, and offloading have become major bottlenecks of large-scale deep learning. To address this problem, DeepSpeed provides a suite of I/O optimizations collectively called DeepNVMe.
+
+DeepNVMe improves the performance and efficiency of I/O-bound deep learning applications by accelerating I/O operations and reducing hardware requirements. It achieves this by leveraging storage technologies such as Non-Volatile Memory Express (NVMe) SSDs and NVIDIA Magnum IO™ GPUDirect® Storage (GDS). In this blog we demonstrate the benefits of DeepNVMe using microbenchmarks and an inference application. In experiments run on an Azure NC96ads_A100_v4 VM, DeepNVMe saturates the available NVMe bandwidth for data transfers to GPU or CPU memory, reaching up to 10 GB/sec of reads and 5 GB/sec of writes.
+
+# Background
+
+High-performance access to persistent storage is a common challenge in many computing domains, including deep learning, and many hardware and software solutions have been proposed for it. DeepNVMe builds on three of these solutions: (1) NVMe SSDs, (2) NVIDIA GDS, and (3) Linux asynchronous I/O (libaio). We briefly describe each of these technologies.
+
+NVMe SSDs are flash-based storage devices that are replacing much slower hard disk drives (HDDs) as the primary persistent storage in modern servers. For example, an Azure NC96ads_A100_v4 VM is equipped with four NVMe SSDs, each capable of 3.25 GB/sec reads, which can be combined in a RAID-0 configuration for a theoretical aggregate read bandwidth of 13 GB/sec. NVIDIA GDS enables direct transfers between NVMe and GPU memory, avoiding the inefficiency of the traditional approach that stages data through an intermediate CPU-memory bounce buffer. NVIDIA GDS is available in CUDA versions 11.4 and above. Finally, libaio is an asynchronous I/O stack introduced in Linux to better extract the performance of fast storage devices such as NVMe SSDs compared to the traditional I/O stack.
+
+# DeepNVMe: an I/O Optimization Module for Deep Learning
+
+DeepNVMe is a Python module built around two key design principles. First, it leverages the storage technologies described above to implement optimizations such as non-blocking I/O operations, bulk submission of I/O operations, parallelization of individual I/O operations, and a lightweight runtime. Second, it exposes these I/O optimizations through a simple POSIX-like interface, making them easy to integrate into deep learning applications and letting those applications benefit from the underlying technologies without dealing with their complexity directly.
+
+# Evaluation
+
+Our experiments were conducted on an Azure NC96ads_A100_v4 VM. The configuration details are summarized in Table 1.
+
+
+
+
+Table 1: Experimental setup details
+
+
+## Microbenchmarks
+
+We used three benchmarking tools for our evaluation. The first is fio, a popular I/O benchmarking tool written in C. The second is gdsio, NVIDIA's tool for benchmarking GDS performance. The third is ds_io, a tool we created for easy integration with DeepNVMe; it is written in Python, which is more representative of deep learning applications.
+
+## High-Performance I/O with CPU Buffers via NVMe Scaling
+
+Our first microbenchmark evaluation uses fio and ds_io to measure the performance of transferring 1 GB of data between NVMe and CPU memory, with fio configured to use the libaio backend. The results, shown in Figure 1, support two observations. First, DeepNVMe delivers high performance comparable to fio on this microbenchmark, even though it was designed to improve the performance of deep learning applications. Second, DeepNVMe scales almost linearly with the available NVMe bandwidth, reaching read speeds of 10 GB/sec and write speeds of 5 GB/sec.
+
+
+
+
+Figure 1: Scaling data transfers between NVMe and CPU buffers using DeepNVMe
+
+
+## High-Performance I/O with GPU Buffers via NVMe Scaling
+
+Our second microbenchmark evaluation uses gdsio and ds_io to measure the performance of transferring 1 GB of data between NVMe and GPU memory. For this experiment, ds_io is configured with both the traditional bounce-buffer approach and the more efficient GDS approach. The results, shown in Figure 2, support three observations. First, with GDS enabled, DeepNVMe achieves speedups of up to 37% over the traditional bounce-buffer approach. Second, DeepNVMe delivers high performance comparable to (and sometimes exceeding) gdsio, even though it was built for deep learning applications. Third, DeepNVMe saturates the NVMe bandwidth both with and without GDS: with GDS it reaches read speeds of up to 9.6 GB/sec and write speeds of 5 GB/sec, and without GDS it reaches 7 GB/sec reads and 4 GB/sec writes.
+
+
+
+
+Figure 2: Scaling data transfers between NVMe and GPU memory using DeepNVMe
+
+
+## ZeRO-Inference: Generative AI Performance
+
+ZeRO-Inference reduces the hardware cost of inference for massive models by offloading the model weights (parameters) to CPU or NVMe memory, making such models usable even with limited hardware resources. ZeRO-Inference is well suited to throughput-oriented applications such as offline inference and to scenarios with constrained hardware budgets. We use a token-generation workload to evaluate the performance of DeepNVMe's NVMe offloading.
+
+## High-Performance Offloading via NVMe Scaling
+
+We measure generation throughput for inference of the LLAMA3-70B model on a single NVIDIA A100-80GB GPU with a prompt length of 512, a generation length of 32, and a batch size of 96. Figure 3 reports the results of scaling from one to four NVMe SSDs, with and without GDS, for ZeRO-Inference. Two observations can be made from these results. First, GDS consistently outperforms the bounce-buffer approach, speeding up token generation by 10-18%. Second, DeepNVMe scales with the available NVMe bandwidth both with and without GDS. With four NVMe SSDs, DeepNVMe achieves a generation throughput of 7 tokens per second with GDS and 6 tokens per second without it. Our profiling results indicate that DeepNVMe would continue to scale with additional NVMe bandwidth, improving the performance of generative applications at low cost.
+
+
+
+
+Figure 3: Scaling LLAMA3-70B token generation performance with NVMe offloading using DeepNVMe
+
+
+# Summary
+
+In this blog post we introduced DeepNVMe, which optimizes the I/O operations that are a major bottleneck for deep learning scalability. Through optimizations built on storage technologies such as NVMe SSDs and NVIDIA GDS, DeepNVMe enables fast and efficient data transfers between persistent storage and deep learning applications. With DeepNVMe, we achieved a generation throughput of up to 7 tokens per second for LLAMA3-70B token generation with NVMe offloading on a single A100-80GB GPU in an Azure NC96ads_A100_v4 VM. DeepNVMe is open-sourced and available in DeepSpeed versions [0.15.0](https://github.com/microsoft/DeepSpeed/releases/tag/v0.15.0) and above. In future blogs we will report on DeepNVMe improvements for other I/O-bound deep learning applications such as model checkpointing and data loading.
+
+# Acknowledgements
+
+This work is the result of a collaboration between Microsoft and NVIDIA. The contributors include Joe Mayer, Martin Cai, and Olatunji Ruwase from Microsoft, and Kiran Modukuri, Vahid Noormofidi, Sourab Gupta, and Sandeep Joshi from NVIDIA.
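As a rough illustration of the POSIX-like interface described in this post, the sketch below submits an asynchronous NVMe read into a pinned CPU tensor. It assumes the `AsyncIOBuilder().load().aio_handle(...)` constructor and the `async_pread`/`wait` methods of DeepSpeed's async I/O extension; the handle arguments and the file path are illustrative assumptions, not a definitive reference.

```python
import torch
from deepspeed.ops.op_builder import AsyncIOBuilder

# Handle arguments: block_size, queue_depth, single_submit, overlap_events, num_threads
# (values here are illustrative).
aio = AsyncIOBuilder().load().aio_handle(1024 * 1024, 128, False, False, 1)

# 1 GB pinned CPU buffer, matching the transfer size used in the microbenchmarks above.
buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8).pin_memory()

aio.async_pread(buffer, '/local_nvme/input.bin')  # non-blocking submit (path is hypothetical)
# ... overlap other work here ...
aio.wait()                                        # block until the read completes
```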
diff --git a/csrc/aio/common/deepspeed_aio_common.cpp b/csrc/aio/common/deepspeed_aio_common.cpp
index 0f2895dfa328..a65cc500cc82 100644
--- a/csrc/aio/common/deepspeed_aio_common.cpp
+++ b/csrc/aio/common/deepspeed_aio_common.cpp
@@ -301,9 +301,8 @@ int regular_read(const char* filename, std::vector& buffer)
} while (r > 0);
if (read_bytes != num_bytes) {
- std::cerr << "read error "
- << " read_bytes (read) = " << read_bytes << " num_bytes (fstat) = " << num_bytes
- << std::endl;
+ std::cerr << "read error " << " read_bytes (read) = " << read_bytes
+ << " num_bytes (fstat) = " << num_bytes << std::endl;
}
assert(read_bytes == num_bytes);
close(fd);
diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp
index 387b713f2bfc..0556f5aa8168 100644
--- a/csrc/aio/py_lib/deepspeed_py_aio.cpp
+++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp
@@ -72,9 +72,8 @@ int deepspeed_py_aio_write(const torch::Tensor& buffer,
const std::chrono::duration fn_time =
std::chrono::high_resolution_clock::now() - start_time;
- std::cout << "Elapsed time(usec): "
- << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6
- << std::endl;
+ std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
+ << " call = " << fn_time.count() * 1e6 << std::endl;
return 0;
}
@@ -118,8 +117,7 @@ int deepspeed_py_aio_read(torch::Tensor& buffer,
const std::chrono::duration fn_time =
std::chrono::high_resolution_clock::now() - start_time;
- std::cout << "Elapsed time(usec): "
- << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6
- << std::endl;
+ std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
+ << " call = " << fn_time.count() * 1e6 << std::endl;
return 0;
}
diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
index c21e92de9449..23ddabe260d4 100644
--- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
+++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
@@ -93,9 +93,8 @@ int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, co
if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); }
const std::chrono::duration fn_time =
std::chrono::high_resolution_clock::now() - start_time;
- std::cout << "Elapsed time(usec): "
- << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6
- << std::endl;
+ std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
+ << " call = " << fn_time.count() * 1e6 << std::endl;
return 0;
}
@@ -128,9 +127,8 @@ int deepspeed_aio_handle_t::write(const torch::Tensor& buffer,
const std::chrono::duration fn_time =
std::chrono::high_resolution_clock::now() - start_time;
- std::cout << "Elapsed time(usec): "
- << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6
- << std::endl;
+ std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
+ << " call = " << fn_time.count() * 1e6 << std::endl;
return 0;
}
diff --git a/csrc/aio/py_lib/deepspeed_py_copy.cpp b/csrc/aio/py_lib/deepspeed_py_copy.cpp
index 8a59107dd347..c597b91d05c9 100644
--- a/csrc/aio/py_lib/deepspeed_py_copy.cpp
+++ b/csrc/aio/py_lib/deepspeed_py_copy.cpp
@@ -10,7 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include "deepspeed_py_copy.h"
#include
-#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
+#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
#if defined(__AVX512__) or defined(__AVX256__)
union AVX_Data {
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h
index 2a4300c5cac1..c102234a4dfb 100644
--- a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h
+++ b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h
@@ -125,11 +125,10 @@ struct CheckArch {
std::cerr << #PTR " is not correctly aligned\n"; \
return false; \
}
-#define EVOFORMER_CHECK(COND, ERR) \
- if (!(COND)) { \
- std::cerr << "[Evoformer Attention]" \
- << "'" #COND "' failed: " << ERR << "\n"; \
- return false; \
+#define EVOFORMER_CHECK(COND, ERR) \
+ if (!(COND)) { \
+ std::cerr << "[Evoformer Attention]" << "'" #COND "' failed: " << ERR << "\n"; \
+ return false; \
}
#endif
diff --git a/csrc/includes/simd.h b/csrc/includes/simd.h
index f5bfb45dd2e2..a205026ec7c1 100644
--- a/csrc/includes/simd.h
+++ b/csrc/includes/simd.h
@@ -27,7 +27,7 @@ inline void writeAs(void* dst, const T& val)
std::memcpy(dst, &val, sizeof(T));
}
-#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
+#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
diff --git a/csrc/xpu/includes/simd.h b/csrc/xpu/includes/simd.h
old mode 100755
new mode 100644
index f77568be7835..097e2d8585cc
--- a/csrc/xpu/includes/simd.h
+++ b/csrc/xpu/includes/simd.h
@@ -13,7 +13,7 @@
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__)
-#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
+#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
diff --git a/csrc/xpu/includes/type_shim.h b/csrc/xpu/includes/type_shim.h
index fa41757c895b..1897afd1fea2 100644
--- a/csrc/xpu/includes/type_shim.h
+++ b/csrc/xpu/includes/type_shim.h
@@ -82,11 +82,11 @@
}
template
-__inline__ __attribute__((always_inline)) T reduce_block_into_lanes(
- T* x,
- T val,
- int lanes = 1,
- bool share_result = false) // lanes is intended to be <= 32.
+__inline__ __attribute__((always_inline)) T
+reduce_block_into_lanes(T* x,
+ T val,
+ int lanes = 1,
+ bool share_result = false) // lanes is intended to be <= 32.
{
auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>();
int tid = item_ct1.get_local_id(2) + item_ct1.get_local_id(1) * item_ct1.get_local_range(2);
diff --git a/deepspeed/linear/__init__.py b/deepspeed/linear/__init__.py
index a27f1c3eaee7..9931a95a0a40 100644
--- a/deepspeed/linear/__init__.py
+++ b/deepspeed/linear/__init__.py
@@ -5,3 +5,4 @@
from .optimized_linear import OptimizedLinear
from .config import LoRAConfig, QuantizationConfig
+from .context_manager import Init, init_lora
diff --git a/deepspeed/linear/config.py b/deepspeed/linear/config.py
index ae9050a3c92b..2632ce7de9c4 100644
--- a/deepspeed/linear/config.py
+++ b/deepspeed/linear/config.py
@@ -3,7 +3,8 @@
# DeepSpeed Team
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import List
@dataclass
@@ -17,10 +18,19 @@ class LoRAConfig:
base_weight_sharding (int): The degree to which the base weights are sharded,
should typically be set to the data-parallel world size to maximize the memory
reduction benefits. Defaults to 1, which means this feature is disabled.
+ offload (bool): whether to offload frozen parameters to CPU when not in use
+ offload_ratio (float): fraction of eligible frozen parameters to offload to CPU when not in use
+ delay_lora_init (bool): if True, defer LoRA parameter initialization so it can be done manually later via init_lora
+ target_mods (List[str]): module names to apply LoRA to; defaults to the Llama-3.1 attention and MLP projection names
"""
lora_r: int = 64
lora_alpha: float = 16.
base_weight_sharding: int = 1
+ offload: bool = False
+ offload_ratio: float = 0.0
+ delay_lora_init: bool = False
+ target_mods: List[str] = field(
+ default_factory=lambda: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])
@dataclass
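For context, a short example of constructing the extended config; the values are illustrative and only the field names come from the dataclass above.

```python
from deepspeed.linear import LoRAConfig

lora_config = LoRAConfig(
    lora_r=16,
    lora_alpha=16.0,
    base_weight_sharding=8,    # shard frozen base weights across 8 data-parallel ranks
    offload=True,              # offload frozen parameters to CPU when not in use
    offload_ratio=0.5,         # offload half of the eligible frozen parameters
    delay_lora_init=True,      # defer LoRA init until init_lora(model) is called
    target_mods=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
)
```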
diff --git a/deepspeed/linear/context_manager.py b/deepspeed/linear/context_manager.py
new file mode 100644
index 000000000000..204fa0fe9c1d
--- /dev/null
+++ b/deepspeed/linear/context_manager.py
@@ -0,0 +1,90 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .optimized_linear import LoRAOptimizedLinear, OptimizedLinear
+
+import torch
+
+try:
+ import transformers
+except ImportError:
+ transformers = None
+
+
+def init_lora(model):
+ model.requires_grad_(False)
+ for m in model.modules():
+ if isinstance(m, LoRAOptimizedLinear):
+ m.init_lora()
+
+
+class Init(object):
+ """
+ Init context wrapper similar in style to zero.Init. Allows for injecting OptimizedLinear during model
+ construction which will shard base weights and reduce overall memory usage during model init. Primarily
+ useful when initializing a model via transformers.AutoModelForCausalLM.
+
+ Example usage:
+ lora_config = deepspeed.linear.LoRAConfig(..)
+ quant_config = deepspeed.linear.QuantizationConfig(..)
+ with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config):
+ model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-405B")
+
+ """
+
+ def __init__(self, lora_config=None, quant_config=None):
+ self._orig_nn_linear = torch.nn.Linear
+ self._orig_causallm_pretrained = None
+ if transformers != None:
+ self._orig_causallm_pretrained = transformers.AutoModelForCausalLM.from_pretrained
+ self._orig_causallm_config = transformers.AutoModelForCausalLM.from_config
+ self.lora_config = lora_config
+ self.quant_config = quant_config
+ self._post_init_complete = False
+
+ def __enter__(self):
+
+ class OptLinearWrapper:
+ _orig_nn_linear = self._orig_nn_linear
+ _lora_config = self.lora_config
+ _quant_config = self.quant_config
+
+ def __new__(self, *args, **kwargs):
+ self._lora_config.delay_lora_init = True
+ kwargs['lora_config'] = self._lora_config
+ kwargs['quantization_config'] = self._quant_config
+ kwargs['linear_cls'] = self._orig_nn_linear
+ return OptimizedLinear(*args, **kwargs)
+
+ def _model_init(model):
+ if self.lora_config != None:
+ init_lora(model)
+ self._post_init_complete = True
+ return model
+
+ # ensures non-lora params are frozen and lora weights are initialized
+ def from_pretrained(*args, **kwargs):
+ model = self._orig_causallm_pretrained(*args, **kwargs)
+ return _model_init(model)
+
+ def from_config(*args, **kwargs):
+ model = self._orig_causallm_config(*args, **kwargs)
+ return _model_init(model)
+
+ torch.nn.Linear = OptLinearWrapper
+ if transformers != None:
+ transformers.AutoModelForCausalLM.from_pretrained = from_pretrained
+ transformers.AutoModelForCausalLM.from_config = from_config
+
+ def __exit__(self, *args, **kwargs):
+ torch.nn.Linear = self._orig_nn_linear
+ if not self._post_init_complete:
+ print('WARNING: For some reason LoRA modules are not initialized, this is usually done automatically '
+ 'if using transformers via (AutoModelForCausalLM from_pretrained/from_config). '
+ 'You must call `init_lora` on each module in order to use DeepSpeed LoRA, otherwise '
+ 'you will error out during runtime.')
+ else:
+ transformers.AutoModelForCausalLM.from_pretrained = self._orig_causallm_pretrained
+ transformers.AutoModelForCausalLM.from_config = self._orig_causallm_config
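The warning in `__exit__` points at the manual path: if a model is built inside the context without going through `AutoModelForCausalLM.from_pretrained`/`from_config`, `init_lora` must be called explicitly. A minimal sketch of that path, using a throwaway two-layer model purely for illustration:

```python
import torch
import deepspeed
from deepspeed.linear import LoRAConfig, init_lora

lora_config = LoRAConfig(lora_r=16, lora_alpha=16)

with deepspeed.linear.Init(lora_config=lora_config):
    # Any nn.Linear constructed inside the context is swapped for OptimizedLinear;
    # bias=False because LoRAOptimizedLinear does not support bias.
    model = torch.nn.Sequential(
        torch.nn.Linear(64, 64, bias=False, dtype=torch.bfloat16),
        torch.nn.Linear(64, 64, bias=False, dtype=torch.bfloat16),
    )

# Built without transformers' from_pretrained/from_config, so LoRA modules must be
# initialized manually before use, as the warning above states.
init_lora(model)
```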
diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py
index e982785a8122..3720196aa255 100644
--- a/deepspeed/linear/optimized_linear.py
+++ b/deepspeed/linear/optimized_linear.py
@@ -40,7 +40,9 @@ def __new__(self,
bias: bool = False,
lora_config: LoRAConfig = None,
quantization_config: QuantizationConfig = None,
- dtype=torch.bfloat16):
+ device=None,
+ dtype=torch.bfloat16,
+ linear_cls=nn.Linear):
if quantization_config is not None and not is_dataclass(quantization_config):
raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
@@ -48,7 +50,7 @@ def __new__(self,
raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")
if lora_config is None and quantization_config is None:
# Everything disabled, fall back to normal nn.Linear
- self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype)
+ self = linear_cls(input_dim, output_dim, bias=bias, dtype=dtype, device=device)
elif lora_config:
# lora enabled, quantization may or may not be
@@ -57,7 +59,9 @@ def __new__(self,
bias=bias,
lora_config=lora_config,
quantization_config=quantization_config,
- dtype=dtype)
+ dtype=dtype,
+ device=device,
+ linear_cls=linear_cls)
elif quantization_config:
# only quantization enabled, no lora
@@ -78,57 +82,121 @@ def __init__(self,
lora_config: LoRAConfig = None,
quantization_config: QuantizationConfig = None,
device=None,
- dtype=torch.bfloat16):
+ dtype=torch.bfloat16,
+ linear_cls=nn.Linear):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.bias = bias
self.lora_config = lora_config
self.quantization_config = quantization_config
- device = get_accelerator().current_device_name() if device is None else device
+ self.device = get_accelerator().current_device_name() if device is None else device
+ self.linear_cls = linear_cls
+ self.dtype = dtype
assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"
-
+ assert not self.bias, "bias=True is not supported by LoRAOptimizedLinear"
self.zero_shards = self.lora_config.base_weight_sharding
self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
- w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
- torch.nn.init.xavier_uniform_(w)
+ if self.zero_shards > 1:
+ assert self.zero_shards == dist.get_world_size(
+ ), "base weight sharding is only supported across world size"
+ w = torch.nn.Parameter(torch.empty(self.output_dim * self.sharded_weight_size, dtype=dtype),
+ requires_grad=False)
+ else:
+ w = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=dtype), requires_grad=False)
+ torch.nn.init.xavier_uniform_(w.reshape(self.sharded_weight_size, self.output_dim))
if self.quantization_config is not None:
assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
- self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
+ self.weight = QuantizedParameter(w, quantization_config=quantization_config)
else:
- self.base_weight = w
+ self.weight = w
+
+ self.disabled = False
+ self._initialized = False
+ if not self.lora_config.delay_lora_init:
+ self.init_lora()
+
+ def disable(self):
+ self.disabled = True
+ self.weight = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=self.dtype),
+ requires_grad=False)
+
+ def init_lora(self):
+ if self.disabled:
+ return
+
+ if self.quantization_config is not None:
+ # ensure quant-param wasn't stripped, in some cases transformers will do this during model init
+ if not isinstance(self.weight, QuantizedParameter):
+ self.weight = QuantizedParameter(self.weight, quantization_config=self.quantization_config)
+
+ self._initialized = True
+ self.weight.requires_grad = False
- self.base_weight.requires_grad = False
+ # Mark base weight to prevent broadcast and ensure proper offload behavior
+ self.weight.ds_optim_param = True
+
+ self.lora_scaling_factor = self.lora_config.lora_alpha / self.lora_config.lora_r
- # Use RS lora for now.
- self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r)
# Keeping lora weights in bf16 precision for ease of training.
- self.lora_weight_1 = nn.Linear(self.input_dim,
- self.lora_config.lora_r,
- bias=self.bias,
- device=device,
- dtype=dtype)
- self.lora_weight_2 = nn.Linear(self.lora_config.lora_r,
- self.output_dim,
- bias=self.bias,
- device=device,
- dtype=dtype)
+ self.lora_weight_1 = self.linear_cls(self.input_dim,
+ self.lora_config.lora_r,
+ bias=self.bias,
+ device=self.device,
+ dtype=self.dtype)
+ self.lora_weight_2 = self.linear_cls(self.lora_config.lora_r,
+ self.output_dim,
+ bias=self.bias,
+ device=self.device,
+ dtype=self.dtype)
+
+ # initialize "A" with kaiming uniform and "B" with zeros following this
+ # https://github.com/huggingface/peft/blob/62122b5add8d6892f70c82eaef2147a6ba33b90b/src/peft/tuners/lora/layer.py#L155
+ nn.init.kaiming_uniform_(self.lora_weight_1.weight, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_weight_2.weight)
self.lora_weight_1.weight.requires_grad = True
self.lora_weight_2.weight.requires_grad = True
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+ error_msgs):
+ if not any([target in prefix for target in self.lora_config.target_mods]):
+ # module does not match any target_mods, we must revert to normal nn.Linear via disable
+ self.disable()
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
+ unexpected_keys, error_msgs)
+
+ if self.zero_shards > 1:
+ if not dist.is_initialized():
+ raise RuntimeError(
+ "attempting to use optimized linear base weight sharding but torch-distributed is not initialized, please init first."
+ )
+ rank = dist.get_rank()
+ shape_local = self.output_dim * self.sharded_weight_size
+ base_weight_name = f"{prefix}weight"
+ incoming_param = state_dict[base_weight_name]
+ state_dict[base_weight_name] = incoming_param.flatten().narrow(0, rank * shape_local, shape_local)
+
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+ error_msgs)
+
def full_weight(self):
- # This assumes weights are evenly sharded across gpus. which might not be correct.
- # in that case, we should flatten before all_gather.
- local_weight = self.base_weight.dequantized() if isinstance(self.base_weight,
- QuantizedParameter) else self.base_weight
- tensor_list = [
- torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype)
- for _ in range(self.zero_shards)
- ]
- dist.all_gather(tensor_list, local_weight)
- weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1))
- return weight
+ base_weight = self.weight
+ if getattr(base_weight, 'ds_offload', False):
+ # move to gpu so we can dequant and all-gather
+ assert base_weight.device == torch.device('cpu'), \
+ f"expected base weight on cpu but found {base_weight.device}"
+ base_weight.offload(revert=True)
+ local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight
+ base_weight.offload()
+ else:
+ local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight
+
+ tensor_out = torch.empty(self.output_dim * self.input_dim,
+ dtype=local_weight.dtype,
+ device=local_weight.device)
+ dist.all_gather_into_tensor(tensor_out, local_weight)
+ return tensor_out.reshape(self.output_dim, self.input_dim)
def linear_without_F_linear(self, input, weight):
output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
@@ -136,14 +204,18 @@ def linear_without_F_linear(self, input, weight):
return output
def forward(self, input_tensor):
+ if self.disabled:
+ return F.linear(input_tensor, self.weight)
+ assert self._initialized, "init_lora was never called, please initialize before proceeding"
+
# Gather the sharded base weight
if self.zero_shards > 1:
with torch.no_grad():
base_weight = self.full_weight()
elif self.quantization_config:
- base_weight = self.base_weight.dequantized()
+ base_weight = self.weight.dequantized()
else:
- base_weight = self.base_weight
+ base_weight = self.weight
base_weight_output = F.linear(input_tensor, base_weight)
lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
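The sharding arithmetic above is easiest to see with plain tensors: each rank keeps a flattened `output_dim * sharded_weight_size` slice selected with `narrow` in `_load_from_state_dict`, and `full_weight()` reassembles the slices in rank order. A toy rendition without any distributed setup (sizes are made up):

```python
import torch

output_dim, input_dim, world_size = 8, 16, 4
full = torch.arange(output_dim * input_dim, dtype=torch.float32)  # flattened full weight

sharded_weight_size = input_dim // world_size        # per-rank share of input_dim
shape_local = output_dim * sharded_weight_size       # flattened shard size, as in the diff
shards = [full.narrow(0, rank * shape_local, shape_local) for rank in range(world_size)]

# all_gather_into_tensor concatenates shards in rank order; reshaping then recovers
# the (output_dim, input_dim) matrix used by F.linear in forward().
reassembled = torch.cat(shards).reshape(output_dim, input_dim)
assert torch.equal(reassembled.flatten(), full)
```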
diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py
index 8e4f23dfba89..70fabea845ba 100644
--- a/deepspeed/linear/quantization.py
+++ b/deepspeed/linear/quantization.py
@@ -75,6 +75,13 @@ def dequantized(self) -> torch.Tensor:
q_mantisa_bits=self.quantization_config.mantissa_bits)
return self.data
+ def offload(self, revert=False):
+ if getattr(self, 'ds_offload', False):
+ if revert:
+ self.data = self.to(get_accelerator().current_device_name())
+ else:
+ self.data = self.to('cpu')
+
def __getstate__(self):
state = self.__dict__
state["data"] = self.data
@@ -104,7 +111,9 @@ def __copy__(self):
return new_instance
def cuda(self, device=None, non_blocking=False):
- return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
+ device = "cuda" if device is None else device
+ self.quantizer.to(device, non_blocking=non_blocking)
+ return self.to(device, non_blocking=non_blocking)
def to(self, *args, **kwargs):
"""
@@ -112,6 +121,7 @@ def to(self, *args, **kwargs):
quantize it.
"""
tensor = super().to(*args, **kwargs)
+ self.quantizer.to(*args, **kwargs)
self._ensure_quantized(tensor)
return tensor
diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py
index 170954e0cf71..edd4ef57302c 100644
--- a/deepspeed/ops/fp_quantizer/quantize.py
+++ b/deepspeed/ops/fp_quantizer/quantize.py
@@ -91,6 +91,13 @@ def quantize(self,
return out
+ def to(self, *args, **kwargs):
+ # Intermediate tensors may need to be moved to different devices
+ if hasattr(self, 'input_q'):
+ self.input_q = self.input_q.to(*args, **kwargs)
+ if hasattr(self, 'scale'):
+ self.scale = self.scale.to(*args, **kwargs)
+
def get_scales(self):
return fp_quant_module.get_scales(self.scale, self.num_groups)
diff --git a/deepspeed/runtime/eigenvalue.py b/deepspeed/runtime/eigenvalue.py
index df63854dd1ca..36300eb904dd 100755
--- a/deepspeed/runtime/eigenvalue.py
+++ b/deepspeed/runtime/eigenvalue.py
@@ -7,6 +7,7 @@
from deepspeed.utils import log_dist
import numpy as np
import logging
+from deepspeed.utils.torch import required_torch_version
class Eigenvalue(object):
@@ -36,12 +37,15 @@ def __init__(self,
ranks=[0])
# Replace all nan/pos-inf/neg-inf to zero
- # TODO: Pytorch new version may add this function, replace this one by then.
def nan_to_num(self, x):
- device = x.device
- x = x.cpu().numpy()
- x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
- return torch.from_numpy(x).to(device)
+ if required_torch_version(min_version=1.8):
+ return torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
+ else:
+ # Fallback to numpy based implementation for backwards-compatibility with PyTorch 1.7 or older versions.
+ device = x.device
+ x = x.cpu().numpy()
+ x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
+ return torch.from_numpy(x).to(device)
def normalize(self, v):
norm_squared = self.inner_product(v, v)
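A quick check of the new fast path (available from PyTorch 1.8 onward) against the behavior of the numpy fallback:

```python
import torch

x = torch.tensor([float('nan'), float('inf'), float('-inf'), 1.5])
# nan/pos-inf/neg-inf are all mapped to zero, matching the numpy fallback above.
print(torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0))  # tensor([0.0000, 0.0000, 0.0000, 1.5000])
```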
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index d40141132aaf..1c74c0c735a0 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -35,6 +35,8 @@
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
+from deepspeed.linear.optimized_linear import LoRAOptimizedLinear
+
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
@@ -326,6 +328,8 @@ def __init__(self,
self.sparse_tensor_module_names.add(name + ".weight")
logger.info("Will convert {} to sparse tensor during training".format(name))
+ self._optimized_linear_offload_setup()
+
self.save_non_zero_checkpoint = False
self.save_zero_checkpoint = False
if not isinstance(self.optimizer, DeepSpeedZeRoOffload):
@@ -363,6 +367,43 @@ def __init__(self,
self._is_compiled = False
+ def _optimized_linear_offload_setup(self):
+ self.optimized_linear_base_weight_sharding = False
+ self.optimized_linear_lora_enabled = False
+ offload_ratio = None
+ for _, module in self.module.named_modules():
+ if isinstance(module, LoRAOptimizedLinear):
+ self.optimized_linear_lora_enabled = True
+ offload_ratio = None
+ if offload_ratio is not None:
+ assert offload_ratio == module.lora_config.offload_ratio, \
+ "all lora_config offload ratios should be the same across the model"
+ offload_ratio = module.lora_config.offload_ratio
+ if module.zero_shards > 1:
+ # set attr so checkpoint saving can handle BWS properly
+ self.optimized_linear_base_weight_sharding = True
+
+ if offload_ratio is None:
+ # Nothing enabled, do nothing
+ return
+
+ total_params = 0
+ for _, p in self.module.named_parameters():
+ if hasattr(p, 'ds_optim_param'):
+ total_params += p.numel()
+
+ offload_limit = total_params * offload_ratio
+ logger.info(f'offloading {offload_ratio*100}% of eligible params, specifically {offload_limit} params')
+ total_offloaded = 0
+ for _, p in self.module.named_parameters():
+ if hasattr(p, 'ds_optim_param'):
+ if total_offloaded < offload_limit:
+ total_offloaded += p.numel()
+ p.ds_offload = True
+ p.offload()
+ else:
+ p.ds_offload = False
+
def destroy(self):
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
self.optimizer.destroy()
@@ -1054,9 +1095,12 @@ def _broadcast_model(self):
def is_replicated(p):
if hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE:
return False
+ elif hasattr(p, 'ds_optim_param'):
+ # do not broadcast OptimizedLinear parameters, they are unique per base weight shard
+ return False
return True
- for p in self.module.parameters():
+ for n, p in self.module.named_parameters():
# Broadcast the model for different parameters
if is_moe_param(p):
if torch.is_tensor(p) and is_replicated(p):
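The selection in `_optimized_linear_offload_setup` is greedy: parameters tagged with `ds_optim_param` are offloaded in iteration order until `offload_ratio` of their total element count is covered. A toy rendition with plain integers (sizes are made up):

```python
param_sizes = [4096 * 4096] * 8      # eight eligible base-weight shards (illustrative)
offload_ratio = 0.5

offload_limit = sum(param_sizes) * offload_ratio

total_offloaded = 0
decisions = []
for numel in param_sizes:
    if total_offloaded < offload_limit:
        total_offloaded += numel
        decisions.append(True)       # would set p.ds_offload = True and call p.offload()
    else:
        decisions.append(False)      # stays resident on the accelerator
print(decisions)                      # [True, True, True, True, False, False, False, False]
```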
diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py
index f17cfa883cc6..e809fe1118b5 100644
--- a/deepspeed/sequence/layer.py
+++ b/deepspeed/sequence/layer.py
@@ -12,48 +12,76 @@
from deepspeed.accelerator import get_accelerator
-def post_all2all(transpose, res_shape):
+def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):
def post_func(input):
- if transpose:
- input = input.transpose(0, 2).contiguous()
- input = input.reshape(res_shape)
- return input
+ if batch_dim_idx == 0:
+ # b, s, n, h
+ if scatter_idx < 2:
+ output = input.permute(1, 2, 0, 3, 4).contiguous()
+ output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head,
+ head_dim).contiguous()
+ else:
+ output = input.permute(1, 0, 2, 3, 4).contiguous()
+ output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size,
+ head_dim).contiguous()
+ else:
+ # s, b, n, h
+ if scatter_idx < 2:
+ output = input.permute(1, 2, 0, 3, 4).contiguous()
+ output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head,
+ head_dim).contiguous()
+ else:
+ output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous()
+ return output
return post_func
-def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, handle=None, type=None):
+def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
seq_world_size = dist.get_world_size(group)
- inp_shape = list(input.shape)
- inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
+ if batch_dim_idx == 0:
+ # b, s, n, h
+ if scatter_idx < 2:
+ bs, global_seq_len, num_local_head, head_dim = input.shape
+ input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
+ head_dim]).contiguous()
+ input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
+ else:
+ bs, local_seq_len, num_total_head, head_dim = input.shape
+ assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
+ input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size,
+ head_dim]).contiguous()
+ input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
+ else:
+ # s, b, n, h
+ if scatter_idx < 2:
+ global_seq_len, bs, num_local_head, head_dim = input.shape
+ input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head,
+ head_dim]).contiguous()
+ else:
+ local_seq_len, bs, num_total_head, head_dim = input.shape
+ assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
+ input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size,
+ head_dim]).contiguous()
+ input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
+
if scatter_idx < 2:
- input_t = input.reshape(
- [seq_world_size, inp_shape[scatter_idx]] + \
- inp_shape[scatter_idx + 1:]
- ).contiguous()
+ post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head,
+ head_dim)
else:
- # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
- input_t = input.reshape(
- [-1, seq_world_size, inp_shape[scatter_idx]] + \
- inp_shape[scatter_idx + 1:]
- ).transpose(0, 1).contiguous()
+ post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head,
+ head_dim)
output = torch.empty_like(input_t)
work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
- res_shape=( inp_shape[: gather_idx] + \
- [inp_shape[gather_idx] * seq_world_size,] + \
- inp_shape[gather_idx + 1:])
- transpose = True if scatter_idx < 2 else False
- post_all2all_fun = post_all2all(transpose, res_shape)
-
if async_op:
if type in ('dq', 'dk'):
handle[type + '_work'] = work
handle[type + '_grad'] = output
handle[type + '_post_all2all_func'] = post_all2all_fun
- return output.view(res_shape)
+ return output
res = post_all2all_fun(output)
return res
@@ -67,6 +95,7 @@ def forward(ctx: Any,
input: Tensor,
scatter_idx: int,
gather_idx: int,
+ batch_dim_idx: int,
stream=None,
handle=None,
type=None,
@@ -77,14 +106,15 @@ def forward(ctx: Any,
ctx.stream = stream
ctx.handle = handle
ctx.type = type
+ ctx.batch_dim_idx = batch_dim_idx
if ctx.handle is None:
- res = single_all_to_all(input, scatter_idx, gather_idx, group, False)
+ res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
else:
# overlap communication path
if not is_fwd and type == 'o':
assert ctx.stream != None
- res = single_all_to_all(input, scatter_idx, gather_idx, group, False)
+ res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
get_accelerator().current_stream().wait_stream(ctx.stream)
del ctx.stream.activation_buffer_list
# The computation of d o_weight can overlap with the communication of d o_input
@@ -92,15 +122,15 @@ def forward(ctx: Any,
elif not is_fwd and type in ('q', 'k'):
# Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv
type = 'd' + type
- res = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type)
+ res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, True, handle, type)
elif is_fwd and type in ('q', 'k'):
# Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v
type = 'fwd_' + type
- res = single_all_to_all(input, scatter_idx, gather_idx, group, False, handle, type)
+ res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False, handle, type)
else:
- res = single_all_to_all(input, scatter_idx, gather_idx, group, False)
+ res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
return res
@@ -108,8 +138,8 @@ def forward(ctx: Any,
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
return (None,
- _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream, ctx.handle,
- ctx.type, False), None, None, None, None, None, None)
+ _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx,
+ ctx.stream, ctx.handle, ctx.type, False), None, None, None, None, None, None, None)
class DistributedAttention(torch.nn.Module):
@@ -148,13 +178,14 @@ def layer_sync(self, layer):
if self.sp_overlap_comm and hasattr(layer, 'done_event'):
self.dafult_stream.wait_event(layer.done_event)
- def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, batch_dim_idx: int, *args: Any, **kwargs) -> Tensor:
""" forward
Arguments:
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
+ batch_dim_idx (int): index of the batch dimension in the query/key/value tensors
args: other args
Returns:
@@ -179,15 +210,15 @@ def pre_hook_fun(grad):
return pre_hook_fun
self.layer_sync(query)
- query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, None,
+ query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'q')
self.layer_sync(key)
- key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, None, self.overlap_handles,
- 'k')
+ key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
+ self.overlap_handles, 'k')
if self.sp_overlap_comm:
self.dafult_stream.wait_stream(self.sp_stream)
- value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, None,
+ value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'v')
if self.sp_overlap_comm:
@@ -205,8 +236,8 @@ def pre_hook_fun(grad):
context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
- output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream,
- self.overlap_handles, 'o')
+ output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx,
+ self.sp_stream, self.overlap_handles, 'o')
#out e.g., [s/p::h]
return output
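The permutes above are easiest to follow as a shape contract. The helper below performs no communication; it only computes the post-all2all shape implied by `post_all2all` for both batch layouts (`sp` is the sequence-parallel world size):

```python
def a2a_output_shape(shape, scatter_idx, batch_dim_idx, sp):
    """Shape produced by post_all2all above for an input of the given shape."""
    if batch_dim_idx == 0:                        # [batch, seq, heads, head_dim]
        bs, seq, heads, hd = shape
        if scatter_idx < 2:                       # scatter sequence, gather heads
            return (bs, seq // sp, sp * heads, hd)
        return (bs, sp * seq, heads // sp, hd)    # scatter heads, gather sequence
    seq, bs, heads, hd = shape                    # [seq, batch, heads, head_dim]
    if scatter_idx < 2:
        return (seq // sp, bs, sp * heads, hd)
    return (sp * seq, bs, heads // sp, hd)


# Batch-first q/k/v all-to-all on 4 sequence-parallel ranks: local sequence chunk in,
# full sequence with a quarter of the heads out.
print(a2a_output_shape((2, 2048, 32, 128), scatter_idx=2, batch_dim_idx=0, sp=4))
# (2, 8192, 8, 128)
```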
diff --git a/docs/Gemfile b/docs/Gemfile
new file mode 100644
index 000000000000..f40c61e4575f
--- /dev/null
+++ b/docs/Gemfile
@@ -0,0 +1,24 @@
+source "https://rubygems.org"
+
+gem 'github-pages', group: :jekyll_plugins
+
+# If you have any plugins, put them here!
+group :jekyll_plugins do
+ gem "jekyll-feed"
+ gem "jekyll-paginate"
+ gem "jekyll-remote-theme"
+ gem "jekyll-include-cache"
+ gem "minimal-mistakes-jekyll"
+end
+
+# Windows and JRuby do not include zoneinfo files, so bundle the tzinfo-data gem
+# and associated library.
+install_if -> { RUBY_PLATFORM =~ %r!mingw|mswin|java! } do
+ gem "tzinfo", "~> 1.2"
+ gem "tzinfo-data"
+end
+
+# Performance-booster for watching directories on Windows
+gem "wdm", "~> 0.1.1", :install_if => Gem.win_platform?
+
+gem "webrick", "~> 1.8"
diff --git a/docs/README.md b/docs/README.md
index 0c3aaaeda600..7333a119c7be 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -42,6 +42,16 @@ We now need to install the required Ruby packages for the website.
bundle install
```
+Depending on your environment, you may need to add `webrick` to avoid the following [error](https://talk.jekyllrb.com/t/load-error-cannot-load-such-file-webrick/5417/6):
+
+> gems/gems/jekyll-3.9.5/lib/jekyll/commands/serve/servlet.rb:3:in `require': cannot load such file -- webrick (LoadError)
+
+
+```
+bundle add webrick
+```
+
+
You can now start a local webserver via:
```
bundle exec jekyll serve
diff --git a/op_builder/evoformer_attn.py b/op_builder/evoformer_attn.py
index 6e7721f94e01..af3aa7429775 100644
--- a/op_builder/evoformer_attn.py
+++ b/op_builder/evoformer_attn.py
@@ -41,18 +41,21 @@ def nvcc_args(self):
args.append(f"-DGPU_ARCH={major}{minor}")
return args
- def is_compatible(self, verbose=True):
+ def is_compatible(self, verbose=False):
try:
import torch
except ImportError:
- self.warning("Please install torch if trying to pre-compile kernels")
+ if verbose:
+ self.warning("Please install torch if trying to pre-compile kernels")
return False
if self.cutlass_path is None:
- self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH")
+ if verbose:
+ self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH")
return False
with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f:
if '3.1.0' not in f.read():
- self.warning("Please use CUTLASS version >= 3.1.0")
+ if verbose:
+ self.warning("Please use CUTLASS version >= 3.1.0")
return False
cuda_okay = True
if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda
@@ -60,10 +63,12 @@ def is_compatible(self, verbose=True):
torch_cuda_major = int(torch.version.cuda.split('.')[0])
cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda
if cuda_capability < 7:
- self.warning("Please use a GPU with compute capability >= 7.0")
+ if verbose:
+ self.warning("Please use a GPU with compute capability >= 7.0")
cuda_okay = False
if torch_cuda_major < 11 or sys_cuda_major < 11:
- self.warning("Please use CUDA 11+")
+ if verbose:
+ self.warning("Please use CUDA 11+")
cuda_okay = False
return super().is_compatible(verbose) and cuda_okay
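With warnings now opt-in, incompatibility reasons are only printed when requested. A hedged usage sketch; the `EvoformerAttnBuilder` class name and its availability under `deepspeed.ops.op_builder` are assumptions here:

```python
from deepspeed.ops.op_builder import EvoformerAttnBuilder  # class name assumed

builder = EvoformerAttnBuilder()
if not builder.is_compatible():          # quiet by default after this change
    builder.is_compatible(verbose=True)  # re-run to print the specific warnings
```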
diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py
index c7d2e72b5408..40cf504c2c83 100644
--- a/op_builder/fp_quantizer.py
+++ b/op_builder/fp_quantizer.py
@@ -22,11 +22,12 @@ def __init__(self, name=None):
def absolute_name(self):
return f'deepspeed.ops.fp_quantizer.{self.NAME}_op'
- def is_compatible(self, verbose=True):
+ def is_compatible(self, verbose=False):
try:
import torch
except ImportError:
- self.warning("Please install torch if trying to pre-compile inference kernels")
+ if verbose:
+ self.warning("Please install torch if trying to pre-compile inference kernels")
return False
cuda_okay = True
@@ -35,17 +36,20 @@ def is_compatible(self, verbose=True):
torch_cuda_major = int(torch.version.cuda.split('.')[0])
cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda
if cuda_capability < 8:
- self.warning("NVIDIA Inference is only supported on Ampere and newer architectures")
+ if verbose:
+ self.warning("NVIDIA Inference is only supported on Ampere and newer architectures")
cuda_okay = False
if cuda_capability >= 8:
if torch_cuda_major < 11 or sys_cuda_major < 11:
- self.warning("On Ampere and higher architectures please use CUDA 11+")
+ if verbose:
+ self.warning("On Ampere and higher architectures please use CUDA 11+")
cuda_okay = False
try:
import triton
except ImportError:
- self.warning(f"please install triton==2.3.0 or 2.3.1 if you want to use the FP Quantizer Kernels")
+ if verbose:
+ self.warning(f"please install triton==2.3.0 or 2.3.1 if you want to use the FP Quantizer Kernels")
return False
# triton 2.3.0 and 2.3.1 are okay and the only versions released in 2.3.x before 3.x was released
@@ -59,9 +63,10 @@ def is_compatible(self, verbose=True):
triton_mismatch = major != "2" or minor != "3"
if triton_mismatch:
- self.warning(
- f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0 and 2.3.1 are known to be compatible with these kernels"
- )
+ if verbose:
+ self.warning(
+ f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0 and 2.3.1 are known to be compatible with these kernels"
+ )
return False
return super().is_compatible(verbose) and cuda_okay
diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py
index d1957f39d9a8..45e8628e669f 100755
--- a/op_builder/inference_core_ops.py
+++ b/op_builder/inference_core_ops.py
@@ -23,7 +23,8 @@ def is_compatible(self, verbose=True):
try:
import torch
except ImportError:
- self.warning("Please install torch if trying to pre-compile inference kernels")
+ if verbose:
+ self.warning("Please install torch if trying to pre-compile inference kernels")
return False
cuda_okay = True
@@ -32,11 +33,13 @@ def is_compatible(self, verbose=True):
torch_cuda_major = int(torch.version.cuda.split('.')[0])
cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda
if cuda_capability < 6:
- self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
+ if verbose:
+ self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
cuda_okay = False
if cuda_capability >= 8:
if torch_cuda_major < 11 or sys_cuda_major < 11:
- self.warning("On Ampere and higher architectures please use CUDA 11+")
+ if verbose:
+ self.warning("On Ampere and higher architectures please use CUDA 11+")
cuda_okay = False
return super().is_compatible(verbose) and cuda_okay
diff --git a/op_builder/inference_cutlass_builder.py b/op_builder/inference_cutlass_builder.py
index 51f7931d9435..fda6e74bbf6a 100644
--- a/op_builder/inference_cutlass_builder.py
+++ b/op_builder/inference_cutlass_builder.py
@@ -22,7 +22,8 @@ def is_compatible(self, verbose=True):
try:
import torch
except ImportError:
- self.warning("Please install torch if trying to pre-compile inference kernels")
+ if verbose:
+ self.warning("Please install torch if trying to pre-compile inference kernels")
return False
cuda_okay = True
@@ -31,11 +32,13 @@ def is_compatible(self, verbose=True):
torch_cuda_major = int(torch.version.cuda.split('.')[0])
cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda
if cuda_capability < 6:
- self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
+ if verbose:
+ self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
cuda_okay = False
if cuda_capability >= 8:
if torch_cuda_major < 11 or sys_cuda_major < 11:
- self.warning("On Ampere and higher architectures please use CUDA 11+")
+ if verbose:
+ self.warning("On Ampere and higher architectures please use CUDA 11+")
cuda_okay = False
return super().is_compatible(verbose) and cuda_okay
diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py
index ec7cab91885f..a4e365786a2b 100644
--- a/op_builder/ragged_ops.py
+++ b/op_builder/ragged_ops.py
@@ -23,7 +23,8 @@ def is_compatible(self, verbose=True):
try:
import torch
except ImportError:
- self.warning("Please install torch if trying to pre-compile inference kernels")
+ if verbose:
+ self.warning("Please install torch if trying to pre-compile inference kernels")
return False
cuda_okay = True
@@ -32,11 +33,13 @@ def is_compatible(self, verbose=True):
torch_cuda_major = int(torch.version.cuda.split('.')[0])
cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda
if cuda_capability < 6:
- self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
+ if verbose:
+ self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
cuda_okay = False
if cuda_capability >= 8:
if torch_cuda_major < 11 or sys_cuda_major < 11:
- self.warning("On Ampere and higher architectures please use CUDA 11+")
+ if verbose:
+ self.warning("On Ampere and higher architectures please use CUDA 11+")
cuda_okay = False
return super().is_compatible(verbose) and cuda_okay
diff --git a/op_builder/ragged_utils.py b/op_builder/ragged_utils.py
index 89450e1fd30d..a855f072af8c 100755
--- a/op_builder/ragged_utils.py
+++ b/op_builder/ragged_utils.py
@@ -23,7 +23,8 @@ def is_compatible(self, verbose=True):
try:
import torch
except ImportError:
- self.warning("Please install torch if trying to pre-compile inference kernels")
+ if verbose:
+ self.warning("Please install torch if trying to pre-compile inference kernels")
return False
cuda_okay = True
@@ -32,11 +33,13 @@ def is_compatible(self, verbose=True):
torch_cuda_major = int(torch.version.cuda.split('.')[0])
cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda
if cuda_capability < 6:
- self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
+ if verbose:
+ self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
cuda_okay = False
if cuda_capability >= 8:
if torch_cuda_major < 11 or sys_cuda_major < 11:
- self.warning("On Ampere and higher architectures please use CUDA 11+")
+ if verbose:
+ self.warning("On Ampere and higher architectures please use CUDA 11+")
cuda_okay = False
return super().is_compatible(verbose) and cuda_okay
diff --git a/op_builder/sparse_attn.py b/op_builder/sparse_attn.py
index 188d257ff4ef..2385adc8fe9c 100644
--- a/op_builder/sparse_attn.py
+++ b/op_builder/sparse_attn.py
@@ -27,45 +27,51 @@ def sources(self):
def cxx_args(self):
return ['-O2', '-fopenmp']
- def is_compatible(self, verbose=True):
+ def is_compatible(self, verbose=False):
# Check to see if llvm and cmake are installed since they are dependencies
#required_commands = ['llvm-config|llvm-config-9', 'cmake']
#command_status = list(map(self.command_exists, required_commands))
#deps_compatible = all(command_status)
if self.is_rocm_pytorch():
- self.warning(f'{self.NAME} is not compatible with ROCM')
+ if verbose:
+ self.warning(f'{self.NAME} is not compatible with ROCM')
return False
try:
import torch
except ImportError:
- self.warning(f"unable to import torch, please install it first")
+ if verbose:
+ self.warning(f"unable to import torch, please install it first")
return False
# torch-cpu will not have a cuda version
if torch.version.cuda is None:
cuda_compatible = False
- self.warning(f"{self.NAME} cuda is not available from torch")
+ if verbose:
+ self.warning(f"{self.NAME} cuda is not available from torch")
else:
major, minor = torch.version.cuda.split('.')[:2]
cuda_compatible = (int(major) == 10 and int(minor) >= 1) or (int(major) >= 11)
if not cuda_compatible:
- self.warning(f"{self.NAME} requires CUDA version 10.1+")
+ if verbose:
+ self.warning(f"{self.NAME} requires CUDA version 10.1+")
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
torch_compatible = (TORCH_MAJOR == 1 and TORCH_MINOR >= 5)
if not torch_compatible:
- self.warning(
- f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {TORCH_MAJOR}.{TORCH_MINOR}')
+ if verbose:
+ self.warning(
+ f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {TORCH_MAJOR}.{TORCH_MINOR}')
try:
import triton
except ImportError:
# auto-install of triton is broken on some systems, reverting to manual install for now
# see this issue: https://github.com/microsoft/DeepSpeed/issues/1710
- self.warning(f"please install triton==1.0.0 if you want to use sparse attention")
+ if verbose:
+ self.warning(f"please install triton==1.0.0 if you want to use sparse attention")
return False
if pkg_version:
@@ -76,7 +82,9 @@ def is_compatible(self, verbose=True):
triton_mismatch = installed_triton != "1.0.0"
if triton_mismatch:
- self.warning(f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible")
+ if verbose:
+ self.warning(
+ f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible")
return False
return super().is_compatible(verbose) and torch_compatible and cuda_compatible
diff --git a/op_builder/spatial_inference.py b/op_builder/spatial_inference.py
index 59caf57f938d..8a6b36cce0b0 100644
--- a/op_builder/spatial_inference.py
+++ b/op_builder/spatial_inference.py
@@ -21,7 +21,8 @@ def is_compatible(self, verbose=True):
try:
import torch
except ImportError:
- self.warning("Please install torch if trying to pre-compile inference kernels")
+ if verbose:
+ self.warning("Please install torch if trying to pre-compile inference kernels")
return False
cuda_okay = True
@@ -31,7 +32,8 @@ def is_compatible(self, verbose=True):
cuda_capability = torch.cuda.get_device_properties(0).major
if cuda_capability >= 8:
if torch_cuda_major < 11 or sys_cuda_major < 11:
- self.warning("On Ampere and higher architectures please use CUDA 11+")
+ if verbose:
+ self.warning("On Ampere and higher architectures please use CUDA 11+")
cuda_okay = False
return super().is_compatible(verbose) and cuda_okay
diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py
index 5ee902289448..88b77499cc0e 100755
--- a/op_builder/transformer_inference.py
+++ b/op_builder/transformer_inference.py
@@ -21,7 +21,8 @@ def is_compatible(self, verbose=True):
try:
import torch
except ImportError:
- self.warning("Please install torch if trying to pre-compile inference kernels")
+ if verbose:
+ self.warning("Please install torch if trying to pre-compile inference kernels")
return False
cuda_okay = True
@@ -30,11 +31,13 @@ def is_compatible(self, verbose=True):
torch_cuda_major = int(torch.version.cuda.split('.')[0])
cuda_capability = torch.cuda.get_device_properties(0).major
if cuda_capability < 6:
- self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
+ if verbose:
+ self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
cuda_okay = False
if cuda_capability >= 8:
if torch_cuda_major < 11 or sys_cuda_major < 11:
- self.warning("On Ampere and higher architectures please use CUDA 11+")
+ if verbose:
+ self.warning("On Ampere and higher architectures please use CUDA 11+")
cuda_okay = False
return super().is_compatible(verbose) and cuda_okay
diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt
index c0fc5dba9d33..2e2c880fbeb0 100644
--- a/requirements/requirements-dev.txt
+++ b/requirements/requirements-dev.txt
@@ -1,5 +1,5 @@
accelerate
-clang-format==16.0.2
+clang-format==18.1.3
comet_ml>=3.41.0
deepspeed-kernels ; sys_platform == 'linux'
docutils<0.18
diff --git a/tests/unit/linear/test_ctx.py b/tests/unit/linear/test_ctx.py
new file mode 100644
index 000000000000..e03d13fd6ce2
--- /dev/null
+++ b/tests/unit/linear/test_ctx.py
@@ -0,0 +1,106 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import deepspeed
+import pytest
+from unit.common import DistributedTest
+
+import deepspeed.comm as dist
+from deepspeed.linear import LoRAConfig, init_lora
+from deepspeed.linear.optimized_linear import LoRAOptimizedLinear
+from unit.simple_model import random_dataloader, SimpleModel
+
+try:
+ import transformers
+except ImportError:
+ transformers = None
+
+if transformers is None:
+ pytest.skip("transformers is required for this test", allow_module_level=True)
+
+
+def injection_assert(model):
+ # pick out random linear that should have been replaced and initialized
+ q_proj = model.model.layers[1].self_attn.q_proj
+
+ assert isinstance(q_proj, LoRAOptimizedLinear), "injection did not happen"
+ assert q_proj._initialized, "lora was not initialized properly"
+ assert isinstance(q_proj.lora_weight_1, torch.nn.Linear)
+ assert isinstance(q_proj.lora_weight_2, torch.nn.Linear)
+
+
+class TestEngine(DistributedTest):
+ world_size = 2
+
+ def test_model(self):
+ lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2)
+ quant_config = None
+ hidden_dim = 64
+ nlayers = 4
+
+ with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config):
+ model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers)
+
+ init_lora(model)
+
+ model_norms = [model.linears[i].weight.norm().item() for i in range(nlayers)]
+
+ ds_config = {
+ "train_batch_size": 2,
+ "steps_per_print": 1,
+ "bf16": {
+ "enabled": True
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 0.00015
+ }
+ },
+ "zero_optimization": {
+ "stage": 1
+ }
+ }
+ model, *_ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters())
+
+ engine_norms = [model.module.linears[i].weight.norm().item() for i in range(nlayers)]
+
+ # Ensure that sharded weights are not broadcast during engine init
+ assert engine_norms == model_norms, f"{dist.get_rank()=} base weight norms are not the same after engine init, {engine_norms=} != {model_norms=}"
+
+ data_loader = random_dataloader(model=model,
+ total_samples=50,
+ hidden_dim=hidden_dim,
+ device=model.device,
+ dtype=torch.bfloat16)
+ for n, batch in enumerate(data_loader):
+ loss = model(batch[0], batch[1])
+ model.backward(loss)
+ model.step()
+
+
+class TestInitTransformers(DistributedTest):
+ world_size = 2
+
+ def test_pretrained_init(self):
+ lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2)
+ quant_config = None
+
+ with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config):
+ model = transformers.AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3")
+
+ injection_assert(model)
+
+ def test_config_init(self):
+ lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2)
+ quant_config = None
+
+ config = transformers.AutoConfig.from_pretrained("llamafactory/tiny-random-Llama-3")
+
+ with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config):
+ model = transformers.AutoModelForCausalLM.from_config(config)
+
+ injection_assert(model)