Skip to content

Commit

Permalink
Merge branch 'master' into tohtana/offload_zero_buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
tohtana authored Aug 22, 2024
2 parents c749b05 + 0f0f231 commit 36d6e10
Show file tree
Hide file tree
Showing 20 changed files with 375 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/xpu-max1100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ on:
- "deepspeed/runtime/zero/parameter_offload.py"
- "deepspeed/runtime/pipe/engine.py"
- "deepspeed/runtime/utils.py"
- "opbuilder/xpu/**"
- "op_builder/xpu/**"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>



* [2024/08] [DeepSpeed on Windows](https://github.com/microsoft/DeepSpeed/tree/master/blogs/windows/08-2024/README.md) [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/windows/08-2024/japanese/README.md)]
* [2024/08] [DeepNVMe: Improving DL Applications through I/O Optimizations](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-gds/README.md) [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-gds/japanese/README.md)]
* [2024/07] [DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/README.md) [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/japanese/README.md)]
* [2024/03] [DeepSpeed-FP6:The power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)]
Expand Down
123 changes: 123 additions & 0 deletions blogs/windows/08-2024/japanese/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
<div align="center">

# DeepSpeedのWindowsサポート

</div>

# はじめに

DeepSpeedは、分散学習と推論を簡単かつ効率的に行うための人気のあるオープンソースの深層学習最適化ライブラリです。DeepSpeedは、その豊富かつ高度な最適化機能(例:ZeRO、3D parallelism, MoEなど)のおかげで、Phi-3、Megatron-Turing-530B、BLOOM-176B、Arcticなどの最先端モデルの学習に広く利用されています。しかし、最も普及しているオペレーティングシステムであるMicrosoft Windowsをネイティブにサポートしていなかったため、多くのAI開発者やユーザーが、DeepSpeedの革新的な機能を利用できない状態でした。この問題を解決するため、DeepSpeedの完全な機能をWindows上でネイティブに実行し、Linux上と同じ使いやすさを実現するための取り組みを開始しました。

このブログでは、この取り組みの最初の成果をお知らせします。現在、DeepSpeedはWindowsにインストールし、単一GPUでの学習、ファインチューニング、および推論をネイティブに実行できるようになりました。ここで重要なこととして、インストールと利用は、Linuxとまったく同じように行えます。ファインチューニングと推論のワークロードを通じて、HuggingFace Transformers との統合、LoRAのサポート、CPUオフロードの3つの重要なDeepSpeedの機能が、正しく動作していることが確認できました。このWindowsサポートは、バージョン0.14.5以降で利用可能です。このブログの残りの部分では、これらの成果を示す例を紹介します。

# テスト環境

Windows 11 Version 23H2 および Build 22631.3880 を実行している Surface Laptop Studio 2 でテストを行いました。このハードウェアには、4GBのVRAMを搭載した NVIDIA RTX A2000 GPU が1つ搭載されています。また、PyTorchバージョン 2.3.0 および HuggingFace Transformersバージョン 4.41.2 を使用しました。使用したサンプルスクリプトは[DeepSpeedExamplesリポジトリ](https://github.com/microsoft/DeepSpeedExamples)から取得できます。以下の例を実行する前にリポジトリをクローンしてください。

# インストール

DeepSpeedは、2つの方法でWindowsにインストールできます。より簡単な方法は、pipパッケージマネージャーを使用することで、もう一方はソースからビルドする方法です。どちらの場合も、Python 3.xとCUDAサポート付きのPyTorchが必要です。

## pipを使用したインストール

DeepSpeedをインストールするには、単に次のコマンドを実行します: `pip install deepspeed`
これにより、最新バージョンのDeepSpeed(現時点では0.14.5)がインストールされます。Linux版とは異なり、Windows版ではすべてのオペレーターがすでにビルド済みであるため、CUDA SDKやC++コンパイラをインストールする必要はありません。

<div align="center">
<img src="../media/win_pip_install_deepspeed.png" style="width:6.5in;height:3.42153in" />
</div>

<div align="center">
pipによるWindowsへのDeepSpeedのインストール
</div>


## ソースからのビルド

ソースからDeepSpeedをビルドするには、DeepSpeedリポジトリをクローンし、コンパイルスクリプトである `build_win.bat` を実行する必要があります。

## インストールの検証

インストール方法にかかわらず、`ds_report`を実行してインストールが成功したかどうかを確認できます。出力は次のようになります:

<div align="center">
<img src="../media/ds_report.png" style="width:6.5in;height:3.42153in" />
</div>

<div align="center">
DeepSpeedのWindowsインストールを確認するds_reportの出力
</div>

# 事前学習の例

Windows上でDeepSpeedを使用した事前学習の例として、画像分類モデルCIFAR10と言語モデルBERTの実行例を示します。

## CIFAR10の事前学習

CIFAR10の事前学習に必要なスクリプトとコードは、次のパスにあります: `DeepSpeedExamples\training\cifar`

以下のコマンドを使用してCIFAR10の事前学習を開始できます: `deepspeed cifar10_deepspeed.py –deepspeed`

出力は次のようになります。

<div align="center">
<img src="../media/cifar10_training.png" style="width:6.5in;height:3.42153in" />
</div>

<div align="center">
DeepSpeedによるWindowsでのCIFAR10モデルの事前学習
</div>

## BERTの事前学習

BERTの事前学習に必要なスクリプトとコードは、次のパスにあります: `DeepSpeedExamples\training\HelloDeepSpeed`

以下のコマンドを使用してBERTの事前学習を開始できます: `deepspeed train_bert_ds.py --checkpoint_dir experiment_deepspeed`

出力は次のようになります。

<div align="center">
<img src="../media/bert_training.png" style="width:6.5in;height:3.42153in" />
</div>

<div align="center">
DeepSpeedによるWindowsでのBERTモデルの事前学習
</div>

# ファインチューニングの例

DeepSpeed-Chatアプリケーションの教師ありファインチューニング(supervised fine tuning; SFT)を使用して、ファインチューニングの機能を示します。LoRAおよびCPUオフロードメモリ最適化を有効にして、 HuggingFace の `facebook/opt-125m` モデルのSFTを実施します。この例を実行するためのコマンドラインは次のとおりです: `deepspeed training\step1_supervised_finetuning\main.py --model_name_or_path facebook/opt-125m --gradient_accumulation_steps 8 --lora_dim 128 --only_optimize_lora --print_loss --zero_stage 2 --deepspeed --dtype bf16 --offload --output_dir output`

出力は次のようになります。

<div align="center">
<img src="../media/opt125m_finetuning.png" style="width:6.5in;height:3.42153in" />
</div>

<div align="center">
DeepSpeedを使用したWindowsでの facebook/opt-125m モデルのファインチューニング
</div>

# 推論の例

推論の機能を示すために、トークン生成のためのZeRO-Inferenceを使用します。ZeRO-Inferenceは、CPUまたはNVMeメモリにオフロードすることで推論のハードウェアコストを削減します。ここでは、サンプルスクリプトを使用して、HuggingFaceのLlama-2-7Bモデルを使用したトークン生成を実行します。4GBのVRAMではモデルと生成処理の両方を実効するのに十分ではないため、モデルパラメータをCPUメモリにオフロードします。

次のコマンドラインを使用して、8トークンのプロンプトから32トークンを生成します: `deepspeed run_model.py --model meta-llama/Llama-2-7b-hf --batch-size 64 --prompt-len 8 --gen-len 32 --cpu-offload`

出力は次のようになります。

<div align="center">
<img src="../media/llama2-7b_inference.png" style="width:6.5in;height:3.42153in" />
</div>

<div align="center">
DeepSpeedのZeRO-InferenceによるWindowsでのLLAMA2-7Bのトークン生成
</div>

# まとめ

最も広く使われているオペレーティングシステムであるWindowsで、深層学習フレームワークであるDeepSpeedをネイティブに実行できるようにすることは、多くの人と組織が、今まさに進行中のAI革命の恩恵を受けるための重要な一歩です。このブログでは、この目標に向けたプロジェクトの、最初の成果を共有しました。Windowsのサポートは現在進行中のプロジェクトですが、今回の成果が多くのユーザにとって活用され、またさらに発展していけることを願っています。次のロードマップには、複数のGPUでの実行、モデルパラメータの量子化、パフォーマンスの詳細な分析が含まれます。

# 謝辞

このプロジェクトは、Costin Eseanu、Logan Adams、Elton Zheng、Reza Yazdani Aminabadi、Martin Cai、Olatunji Ruwaseを含むDeepSpeedメンバーによる大きな貢献の結果です。また、この機能を必要とし、様々な問題の解決策や、建設的なフィードバックを提供し、私たちと共に歩んでくれたDeepSpeedユーザーの重要な貢献に感謝します。
24 changes: 21 additions & 3 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def initialize(args=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
config=None,
mesh_param=None,
config_params=None):
"""Initialize the DeepSpeed Engine.
Expand Down Expand Up @@ -144,10 +145,22 @@ def initialize(args=None,
distributed_port=distributed_port,
dist_init_required=dist_init_required)

##TODO: combine reuse mpu as mesh device and vice versa
# Set config using config_params for backwards compat
if config is None and config_params is not None:
config = config_params

mesh_device = None
if mesh_param:
logger.info(f"mesh_param to Initialize mesh device: {mesh_param}")
mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
#if config file has sequence parallelize and data parallelize, then use them to initialize mesh device
elif config is not None:
if "sequence_parallel_size" in config and "data_parallel_size" in config:
logger.info(f"config to Initialize mesh device: {config}")
mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
("data_parallel", "sequence_parallel"))

# Check for deepscale_config for backwards compat
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
Expand All @@ -162,9 +175,8 @@ def initialize(args=None,
assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
config = args.deepspeed_config
assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"

if not isinstance(model, PipelineModule):
config_class = DeepSpeedConfig(config, mpu)
config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
if config_class.hybrid_engine.enabled:
engine = DeepSpeedHybridEngine(args=args,
model=model,
Expand All @@ -188,6 +200,7 @@ def initialize(args=None,
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config=config,
mesh_device=mesh_device,
config_class=config_class)
else:
assert mpu is None, "mpu must be None with pipeline parallelism"
Expand All @@ -208,7 +221,12 @@ def initialize(args=None,
# Restore zero.Init context if necessary
zero.partition_parameters.restore_init_context()

return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
return_items = [
engine,
engine.optimizer,
engine.training_dataloader,
engine.lr_scheduler,
]
return tuple(return_items)


Expand Down
15 changes: 15 additions & 0 deletions deepspeed/comm/comm.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,21 @@ def get_all_ranks_from_group(group=None):
return group_ranks


def initialize_mesh_device(mesh_shape, mesh_dim_names):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
mesh_device = None
if hasattr(cdb, 'init_device_mesh'):
utils.logger.info(f"Initializing mesh device with backend {cdb.name} \
with shape {mesh_shape} and dim names {mesh_dim_names}")
mesh_device = cdb.init_device_mesh(mesh_shape, mesh_dim_names)
else:
if get_rank() == 0:
utils.logger.warning_once(f"Backend {cdb.name} does not support mesh device initialization")
return mesh_device


# Main DeepSpeed Comms. public API.
def init_distributed(dist_backend=None,
auto_mpi_discovery=True,
Expand Down
8 changes: 8 additions & 0 deletions deepspeed/comm/torch.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,14 @@ def _reduce_op(self, op):
op = torch.distributed.ReduceOp.BXOR
return op

def init_device_mesh(self, mesh_shape, mesh_dim_names):
if not required_torch_version(min_version=2.2):
raise RuntimeError(f"Current torch version does not have device mesh"
f"api (torch.__version__: {torch.__version__})")
return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
mesh_shape,
mesh_dim_names=mesh_dim_names)


# This will become a light-weight wrapper around torch.distributed functions
# TODO: create some example to show how this wrapper can help profile communication
Expand Down
6 changes: 6 additions & 0 deletions deepspeed/inference/v2/checkpoint/huggingface_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
for checkpoint in self._all_ckpt_paths:
inference_logger().info(f"Loading checkpoint: {checkpoint}")
checkpoint_sd = self._checkpoint_load_fn(checkpoint)

# If the model has tied embeddings, we need to make sure the lm_head weights are tied to the embeddings weights
if hasattr(self.model_config, "tie_word_embeddings") and self.model_config.tie_word_embeddings:
if self.model_config.model_type == "qwen2":
checkpoint_sd["lm_head.weight"] = checkpoint_sd["model.embed_tokens.weight"]

param_keys = list(checkpoint_sd.keys())
for param_name in param_keys:
param = checkpoint_sd[param_name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@
} else if (4 == N_TOP_K) { \
constexpr int CONST_TOP_K = 4; \
__VA_ARGS__(); \
} else if (8 == N_TOP_K) { \
constexpr int CONST_TOP_K = 8; \
__VA_ARGS__(); \
} \
}()
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class BlockedRotaryEmbeddings(DSKernelBase):

supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_head_sizes = [64, 80, 96, 128]
supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71]
supported_q_ratios = [1, 2, 4, 5, 6, 7, 8, 16, 29, 35, 36, 71]

def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype, rotary_dim: int,
theta_base: float) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ void launch_kv_rotary_kernel(T* kv_cache,
LAUNCH_KV_ROTARY_FOR_Q_RATIO(2)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(4)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(5)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(6)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(7)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(8)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(16)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(29)
Expand Down
Loading

0 comments on commit 36d6e10

Please sign in to comment.