
Commit

Merge branch 'master' into ds-llama
jeffra authored Aug 5, 2024
2 parents 966ebd4 + 0584689 commit 6ec4ead
Showing 70 changed files with 1,475 additions and 167 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/amd-mi200.yml
@@ -32,7 +32,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm5.6
pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm6.0
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
6 changes: 3 additions & 3 deletions .github/workflows/cpu-torch-latest.yml
@@ -19,7 +19,7 @@ concurrency:

jobs:
unit-tests:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04

steps:
- uses: actions/checkout@v4
@@ -50,5 +50,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3"
HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3"
HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.4"
HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.4"
2 changes: 1 addition & 1 deletion .github/workflows/formatting.yml
@@ -18,7 +18,7 @@ jobs:

# formatting and basic install on cpu-only machine
unit-tests:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04

steps:
- uses: actions/checkout@v4
4 changes: 2 additions & 2 deletions .github/workflows/nv-mii.yml
@@ -37,7 +37,7 @@ jobs:

- name: Install pytorch
run: |
pip3 install -U --cache-dir $TORCH_CACHE torch --index-url https://download.pytorch.org/whl/cu118
pip3 install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu118
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
@@ -46,7 +46,7 @@
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
git checkout bdf36dc
git checkout v4.42.4
git rev-parse --short HEAD
pip install .
2 changes: 1 addition & 1 deletion .github/workflows/nv-pre-compile-ops.yml
@@ -21,7 +21,7 @@ concurrency:

jobs:
unit-tests:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
container:
image: deepspeed/gh-builder:ubuntu1804-py38-torch1131-cu116

4 changes: 2 additions & 2 deletions .github/workflows/nv-torch-latest-v100.yml
@@ -55,5 +55,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.3" --cuda_ver="11.8"
pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.3" --cuda_ver="11.8"
pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.4" --cuda_ver="11.8"
pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.4" --cuda_ver="11.8"
2 changes: 1 addition & 1 deletion .github/workflows/python.yml
@@ -24,7 +24,7 @@ jobs:
pyVersion: ["3.7", "3.8", "3.9", "3.10"]
fail-fast: false

runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
container:
image: deepspeed/gh-builder:py${{ matrix.pyVersion }}

2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -7,7 +7,7 @@ on:

jobs:
deploy:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
environment: release-env

steps:
15 changes: 8 additions & 7 deletions accelerator/hpu_accelerator.py
@@ -3,6 +3,7 @@

# DeepSpeed Team

import functools
import os
import pkgutil
import importlib
@@ -196,31 +197,31 @@ def replay_graph(self, graph):
# Tensor operations
@property
def BFloat16Tensor(self):
return self.hpu.BFloat16Tensor
return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')

@property
def ByteTensor(self):
return self.hpu.ByteTensor
return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu')

@property
def DoubleTensor(self):
return self.hpu.DoubleTensor
return functools.partial(torch.tensor, dtype=torch.double, device='hpu')

@property
def FloatTensor(self):
return self.hpu.FloatTensor
return functools.partial(torch.tensor, dtype=torch.float, device='hpu')

@property
def HalfTensor(self):
return self.hpu.HalfTensor
return functools.partial(torch.tensor, dtype=torch.half, device='hpu')

@property
def IntTensor(self):
return self.hpu.IntTensor
return functools.partial(torch.tensor, dtype=torch.int, device='hpu')

@property
def LongTensor(self):
return self.hpu.LongTensor
return functools.partial(torch.tensor, dtype=torch.long, device='hpu')

def pin_memory(self, tensor, align_bytes=1):
return tensor.pin_memory(self.device())
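
The tensor-type properties above now return `functools.partial` wrappers that bind `torch.tensor` to a fixed dtype and the `hpu` device, replacing the `self.hpu.*Tensor` constructors. A minimal sketch of the resulting call pattern (assumes Habana's PyTorch bridge is installed; substitute `'cpu'` for the device to try it without HPU hardware):

```python
import functools

import torch

# Same pattern as the properties above: a callable that always constructs tensors
# with a fixed dtype/device. The 'hpu' device string assumes habana_frameworks.torch
# is importable; use 'cpu' otherwise.
FloatTensor = functools.partial(torch.tensor, dtype=torch.float, device='hpu')

x = FloatTensor([1.0, 2.0, 3.0])  # == torch.tensor([1.0, 2.0, 3.0], dtype=torch.float, device='hpu')
print(x.dtype, x.device)
```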
6 changes: 5 additions & 1 deletion accelerator/xpu_accelerator.py
@@ -26,7 +26,11 @@ def is_synchronized_device(self):
return False

def use_host_timers(self):
return self.is_synchronized_device()
# WA XPU event will be consolidated in 2.5
if ipex.__version__ < '2.5':
return True
else:
return self.is_synchronized_device()

def resolves_data_dependency(self):
return self.is_synchronized_device()
3 changes: 3 additions & 0 deletions blogs/deepspeed-fastgen/README.md
@@ -231,7 +231,10 @@ We currently support the following model architectures in this alpha release of
* [Falcon](https://huggingface.co/models?other=falcon)
* [Mixtral](https://huggingface.co/models?other=mixtral)
* [Phi-2](https://huggingface.co/models?other=phi-msft)
* [Phi-3](https://huggingface.co/models?other=phi3)
* [Qwen](https://huggingface.co/models?other=qwen)
* [Qwen2](https://huggingface.co/models?other=qwen2)
* [Qwen2-MoE](https://huggingface.co/models?other=qwen2_moe)

All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.
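
The newly listed architectures are served through the same DeepSpeed-MII entry point as the existing ones. A minimal usage sketch (the checkpoint name below is only an example of a supported Hugging Face model):

```python
import mii

# DeepSpeed-FastGen fetches the model weights and tokenizer from Hugging Face.
pipe = mii.pipeline("Qwen/Qwen2-7B-Instruct")  # example checkpoint; any supported architecture works

response = pipe(["DeepSpeed-FastGen is"], max_new_tokens=64)
print(response)
```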

4 changes: 4 additions & 0 deletions blogs/deepspeed-fastgen/chinese/README.md
@@ -226,6 +226,10 @@ DeepSpeed-FastGen 是 [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII
* [LLaMA](https://huggingface.co/models?other=llama), [LLaMA-2](https://huggingface.co/models?other=llama-2)
* [Mistral](https://huggingface.co/models?other=mistral)
* [OPT](https://huggingface.co/models?other=opt)
* [Falcon](https://huggingface.co/models?other=falcon)
* [Mixtral](https://huggingface.co/models?other=mixtral)
* [Phi-2](https://huggingface.co/models?other=phi-msft)
* [Qwen](https://huggingface.co/models?other=qwen)

All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.

88 changes: 88 additions & 0 deletions blogs/deepspeed-gds/README.md
@@ -0,0 +1,88 @@
<div align="center">

# DeepNVMe: Improving DL Applications through I/O Optimizations

</div>

# Introduction

Deep Learning (DL) continues to drive unprecedented advancements across important
Artificial Intelligence domains including language, speech, video, and multimodal applications.
A key factor in these advancements is dramatic scalability across multiple dimensions, including model size,
sequence length, and hardware parallelism. From a system perspective, DL scalability puts significant
pressure on essential subsystems including computation, memory, communication, and storage. However,
existing DL optimization efforts have mostly neglected the storage subsystem, making I/O operations such
as data loading, model checkpointing, and offloading the main bottlenecks of large-scale DL. To address
this problem, DeepSpeed has created a suite of I/O optimizations collectively called DeepNVMe.

DeepNVMe improves the performance and efficiency of I/O-bound DL applications by accelerating I/O operations
and reducing hardware requirements. It achieves this by leveraging storage innovations such as Non-Volatile
Memory Express (NVMe) Solid State Drives (SSDs) and NVIDIA Magnum IO<sup>TM</sup> GPUDirect® Storage (GDS). In this
blog we show the benefits of DeepNVMe using microbenchmarks and an inference application. In experiments
conducted on an Azure NC96ads\_A100\_v4 VM, we observed that DeepNVMe saturates available NVMe bandwidth for
data transfers with GPU or CPU memory, achieving up to 10 GB/sec for reads and 5 GB/sec for writes.

# Background
High-performance access to persistent storage is a common challenge in many computing domains, including DL. Thus, a significant number of hardware and software solutions have been proposed. DeepNVMe builds on three such solutions: (1) NVMe SSDs, (2) NVIDIA GDS, and (3) Linux Asynchronous I/O (libaio). We will briefly describe each of these technologies.

NVMe SSDs are Flash-based storage devices that are replacing much slower hard disk drives (HDDs) as primary persistent storage in modern servers. For example, an Azure NC96ads\_A100\_v4 VM is equipped with four NVMe SSDs, each capable of 3.25 GB/sec reads, which can be combined in a RAID-0 configuration for a theoretical aggregate read bandwidth of 13 GB/sec. NVIDIA GDS enables direct transfers between NVMe and GPU memory, avoiding the inefficiencies of the traditional approach of staging data through intermediate CPU memory (a bounce buffer). NVIDIA GDS is generally available in CUDA versions 11.4 and above. Finally, libaio is an asynchronous I/O stack introduced in Linux to extract more of the raw performance of fast storage devices like NVMe SSDs than the traditional I/O stack allows.

# DeepNVMe: an Optimization Module for Deep Learning I/O

DeepNVMe is a Python module that we developed with two key design principles. First, it leverages the storage technologies discussed above to implement powerful optimizations such as non-blocking I/O operations, bulk submission of I/O operations, parallelization of an individual I/O operation, and a lightweight runtime. Second, it exposes these I/O optimizations through a simple POSIX-like interface to foster easy integration into DL applications while hiding the complexities of the underlying technologies.
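
As a rough illustration of that interface, the snippet below drives a DeepNVMe asynchronous I/O handle from Python. It follows DeepSpeed's `async_io` extension (`AsyncIOBuilder`), but the constructor arguments and file path are illustrative placeholders and may differ across DeepSpeed versions:

```python
import torch
from deepspeed.ops.op_builder import AsyncIOBuilder

# Load the DeepNVMe (async_io) extension; the tuning values below are placeholders.
aio = AsyncIOBuilder().load()
handle = aio.aio_handle(1024 * 1024,  # block size
                        128,          # queue depth
                        False,        # single submit
                        True,         # overlap events
                        1)            # I/O threads

# POSIX-like non-blocking read into a pinned CPU buffer, followed by a wait for completion.
buffer = torch.empty(1024 * 1024, dtype=torch.uint8).pin_memory()
handle.async_pread(buffer, '/local_nvme/input.dat')  # hypothetical file on the NVMe volume
num_completed = handle.wait()
```

Write counterparts (e.g. `async_pwrite`) and blocking variants follow the same shape, which is what keeps the interface POSIX-like while request batching and parallel submission happen inside the runtime.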

# Evaluation

Our experiments are conducted on an Azure NC96ads\_A100\_v4 VM with setup details summarized in Table 1. For multi-device experiments, the SSDs are combined in a RAID-0 configuration.

<img src="./media/table1.png" style="width:6.5in;height:3.42153in" />

<div align="center">
Table 1: Experimental setup details
</div>

## Microbenchmark Performance

We used three benchmarking tools for our evaluations. The first is fio, the popular I/O benchmarking tool written in C. The second is gdsio from NVIDIA, for benchmarking GDS performance. The third is ds\_io, a Python tool that we created for easy integration with DeepNVMe and to better represent DL applications, which are commonly Python-based.

## High-Performance I/O with CPU Buffers via NVMe Scaling

Our first set of microbenchmark evaluations used fio and ds\_io to measure the performance of transferring 1 GB of data between NVMe and CPU memory. We configure fio to use the libaio backend for these experiments. The results are summarized in Figure 1, from which we make two observations. First, DeepNVMe demonstrates high performance, roughly matching fio despite being more representative of DL applications. Second, DeepNVMe scales I/O performance almost linearly with available NVMe bandwidth, achieving rates of 10 GB/sec for reads and 5 GB/sec for writes.

<img src="./media/figure1.png" style="width:6.5in;height:3.42153in" />

<div align="center">
Figure 1: Using DeepNVMe to scale data transfers between NVMe and CPU buffer
</div>

## High-Performance I/O with GPU Buffers via NVMe Scaling

Our second set of microbenchmark evaluations used gdsio and ds\_io to measure the performance of transferring 1 GB of data between NVMe and GPU memory. For this experiment, we configure ds\_io to use both the traditional bounce buffer approach and the more efficient GDS approach. The results are summarized in Figure 2, from which we make three observations. First, GDS improves performance in DeepNVMe compared to the traditional bounce buffer approach, with up to a 37% speedup. Second, DeepNVMe demonstrates high performance by matching (and sometimes surpassing) gdsio despite being more representative of DL applications. Third, DeepNVMe, with and without GDS, scales I/O performance with available NVMe bandwidth. With GDS, DeepNVMe achieves a maximum of 9.6 GB/sec reads and 5 GB/sec writes; without GDS, it achieves 7 GB/sec reads and 4 GB/sec writes.

<img src="./media/figure2.png" style="width:6.5in;height:3.42153in" />

<div align="center">
Figure 2: Using DeepNVMe to scale data transfers between NVMe and GPU memory
</div>

## ZeRO-Inference: Generative AI Performance

ZeRO-Inference is an AI democratization technology that reduces the hardware cost of inferencing massive models by using DeepNVMe to offload model weights to CPU or NVMe memory. ZeRO-Inference is well suited for throughput-oriented applications, such as offline inferencing, and for scenarios with a limited hardware budget. We use a token generation workload to evaluate DeepNVMe performance for NVMe offloading.
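
For context, NVMe offloading in ZeRO-Inference is driven by the regular DeepSpeed configuration. The sketch below shows the relevant knobs; the paths, buffer sizes, and aio settings are placeholders rather than the values used in these experiments:

```python
import torch
import deepspeed

# Placeholder model; in practice this would be a large checkpoint such as LLAMA3-70B.
model = torch.nn.Linear(8192, 8192)

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/local_nvme",  # mount point of the (RAID-0) NVMe volume
            "pin_memory": True,
        },
    },
    # DeepNVMe (aio) tuning knobs; values are illustrative.
    "aio": {
        "block_size": 1048576,
        "queue_depth": 128,
        "single_submit": False,
        "overlap_events": True,
    },
}

engine, *_ = deepspeed.initialize(model=model, config=ds_config)
```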

## High-Performance Offloading via NVMe Scaling

We measure the generation throughput of inferencing a LLAMA3-70B model on a single NVIDIA A100-80GB with a prompt length of 512, a generation length of 32, and a batch size of 96. We scale the number of NVMe SSDs from 1 to 4 and present the results for ZeRO-Inference with and without GDS in Figure 3. We make two observations from these results. First, GDS consistently provides better performance than the bounce buffer approach, achieving 10-18% faster token generation. Second, DeepNVMe, with and without GDS, scales generation performance with available NVMe bandwidth. With four NVMe SSDs, DeepNVMe achieves generation throughput of 7 tokens per second with GDS and 6 tokens per second without GDS. Our profiling results suggest that DeepNVMe will continue to scale with more NVMe bandwidth, making it an economical option for boosting generative application performance.

<img src="./media/figure3.png" style="width:6.5in;height:3.42153in" />

<div align="center">
Figure 3: Using DeepNVMe to scale LLAMA3-70B token generation performance with NVMe offloading.
</div>

# Summary

In this blog post, we introduced DeepNVMe, an I/O optimization technology created to tackle the emergence of I/O operations as key bottlenecks of Deep Learning scalability. DeepNVMe enables fast and efficient data transfers between persistent storage and DL application memory through optimizations built on popular storage technologies such as NVMe SSDs and NVIDIA GDS. We showed the benefits of using DeepNVMe for LLAMA3-70B token generation on a single A100-80GB GPU with NVMe offloading, where it achieves up to 7 tokens per second in generation throughput on an Azure NC96ads\_A100\_v4 VM. DeepNVMe will be open-sourced and generally available in DeepSpeed versions >= [0.15.0](https://github.com/microsoft/DeepSpeed/releases/tag/v0.15.0). In future blogs, we will report DeepNVMe improvements for other I/O-bound DL applications such as model checkpointing and data loading.


# Acknowledgements
This work is the result of a deep collaboration between Microsoft and NVIDIA. The contributors include Joe Mayer, Martin Cai, and Olatunji Ruwase from Microsoft; and Kiran Modukuri, Vahid Noormofidi, Sourab Gupta, and Sandeep Joshi from NVIDIA.
Binary file added blogs/deepspeed-gds/media/figure1.png
Binary file added blogs/deepspeed-gds/media/figure2.png
Binary file added blogs/deepspeed-gds/media/figure3.png
Binary file added blogs/deepspeed-gds/media/table1.png
99 changes: 82 additions & 17 deletions csrc/cpu/comm/shm_interface.cpp
@@ -46,15 +46,13 @@ void initialize(int size, int rank)
if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); }
}

int get_rank(int group = 0) { return world_rank; }

int get_world_size(int group = 0) { return world_size; }
void inference_all_reduce_(torch::Tensor& data, int op);

// Success - return 0
// Fail (cannot honor the request and need to fall back) - return -1
int inference_all_reduce(torch::Tensor& data, py::object op)
void inference_all_reduce_(torch::Tensor& data, int op)
{
if (!all_ranks_local_p) return -1;
assert(op == 0);
#ifdef DO_PROFILE
static double total_time = 0.0;
static double total_time_sq = 0.0;
@@ -67,11 +65,6 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
auto start = std::chrono::system_clock::now();
#endif

static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));

assert(py::int_(op.attr("value")) == ReduceOpSum);

auto numel = data.numel();

int data_size = 0;
@@ -84,7 +77,7 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
default: data_type_fallback = true;
}

if (data_type_fallback) return -1;
if (data_type_fallback) return;

all_reduce_outer_loop(data, numel, data_size);

@@ -109,13 +102,85 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
}
}
#endif
return 0;
return;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("initialize", &initialize, "shm initialize"); }

TORCH_LIBRARY(deepspeed, m)
{
m.def("inference_all_reduce(Tensor self) -> Tensor");
m.def("inference_all_reduce_(Tensor(a!) self) -> Tensor(a!)");
}

torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_)
{
torch::Tensor result_ = torch::empty_like(self_);
return result_;
}

torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_) { return self_; }

torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_)
{
TORCH_INTERNAL_ASSERT(self_.device().type() == torch::DeviceType::CPU);
torch::Tensor self_tensor = self_.contiguous();
inference_all_reduce_(self_tensor, 0);
return self_;
}

torch::Tensor inference_all_reduce_cpu(const torch::Tensor& self_)
{
torch::Tensor result = self_.clone();
inference_all_reduce__cpu(result);
return result;
}

#include <ATen/FunctionalTensorWrapper.h>
// The boilerplate functionalization logic, that teaches functionalization
// how to map x_() calls into x() calls.
// Long term, we'd like to not require users to write this logic.
// HOWEVER, if you have a custom op that is mutable,
// You will still need to write an out-of-place version of that op!
at::Tensor& inference_all_reduce__functionalization_glue(at::Tensor& x)
{
// We expect all tensor inputs to our op to be "functional tensors"
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(x));
// First, sync and unwrap and functional tensors
at::functionalization::impl::sync(x);
auto x_ = at::functionalization::impl::from_functional_tensor(x);
// Grab the dispatcher entry corresponding to the out-of-place op, "x"
static auto op_handle = c10::Dispatcher::singleton()
// specify namespace::op_name, op_overload_name
.findSchemaOrThrow("deepspeed::inference_all_reduce", "")
// Specify the C++ schema of the out-of-place op.
.typed<at::Tensor(const at::Tensor&)>();
// Next, redispatch to the out-of-place op, x() (user called x_, we call x)
at::Tensor tmp_output;
{
at::AutoDispatchSkipFunctionalize guard;
tmp_output = op_handle.call(x_);
}
// Finally, tell functionalization about this mutation.
at::functionalization::impl::replace_(x, tmp_output);
at::functionalization::impl::commit_update(x);
at::functionalization::impl::sync(x);
return x;
}

TORCH_LIBRARY_IMPL(deepspeed, CPU, m)
{
m.impl("inference_all_reduce", inference_all_reduce_cpu);
m.impl("inference_all_reduce_", inference_all_reduce__cpu);
}

TORCH_LIBRARY_IMPL(deepspeed, Meta, m)
{
m.impl("inference_all_reduce", inference_all_reduce_meta);
m.impl("inference_all_reduce_", inference_all_reduce__meta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
TORCH_LIBRARY_IMPL(deepspeed, Functionalize, m)
{
m.def("initialize", &initialize, "shm initialize");
m.def("get_rank", &get_rank, "get rank");
m.def("get_world_size", &get_world_size, "get world size");
m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
m.impl("inference_all_reduce_", inference_all_reduce__functionalization_glue);
}
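
With the registrations above, the SHM all-reduce becomes a torch custom op with CPU, Meta, and Functionalize kernels instead of a plain pybind function. A rough sketch of how it could be invoked from Python once the extension is loaded and the CPU comm backend is initialized (op names follow the `TORCH_LIBRARY(deepspeed, ...)` block above; treat the calling convention as illustrative):

```python
import torch

# Assumes the shm_interface extension has been loaded and SHM initialized by DeepSpeed's
# CPU communication backend; otherwise the ops below are not registered.
x = torch.ones(1024, dtype=torch.bfloat16)

y = torch.ops.deepspeed.inference_all_reduce(x)   # out-of-place: returns the reduced copy
torch.ops.deepspeed.inference_all_reduce_(x)      # in-place: mutates x

# Because Meta and Functionalize kernels are registered, the in-place op can also
# appear inside traced/compiled regions.
compiled = torch.compile(lambda t: torch.ops.deepspeed.inference_all_reduce_(t))
```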
