Capture short kernel sequences to graph (microsoft#4318)
**Motivation:**
1. In several places, short sequences of kernels are launched and
executed serially (with no dynamic shapes), and the launch overhead is
much higher than the execution time. Capturing these sequences in a
graph removes most of that overhead. Compared to ```multi-tensor-apply```,
using a graph is more concise and requires only PyTorch as a dependency.
2. Some device software stacks also support lazy-mode PyTorch, which
lets the compiler perform graph optimization. In lazy mode, however,
the operation accumulation time (host time) can become significantly
higher than the device time in this scenario, leaving devices
underutilized. Using the same API as CUDA graph (once it has been added
to the accelerator interface, cc @delock) could also resolve this
issue.

**Change:**
We modified three functions.

For ```update_hp_grads```, we execute the CPU and GPU operations separately because a graph cannot record CPU operations. In addition, the inputs a graph reads must keep a fixed address, or any address change must itself be captured by the capture operation (in that case, set ```replay_first_step``` to ```True```). We therefore changed ```grad=None``` to ```grad.zero_()```. Similarly, some inputs that require fixed addresses are kept in ```graph_cache```.
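To illustrate why `grad=None` was replaced by `grad.zero_()`, here is a minimal sketch. It is not code from this commit, and `lp_grad` is a made-up stand-in for a gradient buffer. It shows that an in-place clear keeps the address a captured graph would have recorded:

```python
import torch

# Hypothetical stand-in for a low-precision gradient buffer read by a captured graph.
lp_grad = torch.ones(4)
address_at_capture = lp_grad.data_ptr()

# In-place clear: the values are reset, but the same storage is reused.
lp_grad.zero_()

address_unchanged = lp_grad.data_ptr() == address_at_capture
is_cleared = bool((lp_grad == 0).all())
```

Setting the gradient to `None` instead drops the buffer, so the next backward pass may allocate it at a different address than the one the graph recorded.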

For ```clip_tensors_by_global_norm```, ```clip_coef``` is a scalar whose value changes between steps, so it must be moved to the GPU when a graph is used.
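A rough sketch of that idea, assuming a persistent device buffer for the coefficient. The names `clip_coef_buf` and `scale_in_place` are hypothetical and the numbers are made up; this is not the commit's implementation:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical persistent buffer: allocated once so its address never changes.
clip_coef_buf = torch.zeros((), dtype=torch.float32, device=device)


def scale_in_place(tensors, coef_buf):
    # Device-side ops only, so this function could be recorded in a graph.
    for t in tensors:
        t.div_(coef_buf)


grads = [torch.full((3,), 4.0, device=device), torch.full((2,), 2.0, device=device)]
max_norm, global_norm, eps = 1.0, 8.0, 1e-6

clip_coef = global_norm / (max_norm + eps)
if clip_coef > 1:
    # Update the value in place instead of passing a new Python scalar.
    clip_coef_buf.fill_(clip_coef)
    scale_in_place(grads, clip_coef_buf)
```

Because the coefficient lives in a tensor with a fixed address, a replayed graph would read the value written for the current step rather than a constant baked in at capture time.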


For ```total_norm = sum([t.data.float().norm(norm_type).item() ** norm_type for t in input_tensors])```, ```item()``` is a synchronous operation, which a graph also does not support. We instead perform the ```sum``` and ```** norm_type``` directly on the GPU.
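The difference between the two patterns can be sketched as follows. The function names are illustrative and this is not the code from the commit; the second variant keeps the result as a tensor and never calls `item()`:

```python
import torch


def sum_of_powers_with_sync(tensors, norm_type=2.0):
    # Original pattern: one .item() call (a host sync) per tensor.
    return sum(t.data.float().norm(norm_type).item() ** norm_type for t in tensors)


def sum_of_powers_on_device(tensors, norm_type=2.0):
    # Graph-friendly pattern: power and sum stay on the device as tensor ops.
    per_tensor = torch.stack([t.data.float().norm(norm_type) for t in tensors])
    return per_tensor.pow(norm_type).sum()


tensors = [torch.tensor([3.0, 4.0]), torch.tensor([1.0, 2.0, 2.0])]
host_value = sum_of_powers_with_sync(tensors)
device_value = sum_of_powers_on_device(tensors)
```

Both variants compute the same quantity (25 + 9 for the inputs above); only the second avoids pulling each partial result back to the host.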

Other similar scenarios can also use ```graph_process()```, or a slightly modified version of it.
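The body of `graph_process()` is not shown in this part of the diff. The following is only a rough sketch of what a capture-once, replay-later helper of this kind could look like with plain PyTorch CUDA graph calls. `run_with_graph` and `_graph_cache` are hypothetical names, and the eager fallback is an assumption added so the sketch also runs without a GPU:

```python
import torch

_graph_cache = {}  # hypothetical cache, keyed by function name


def run_with_graph(replay_first_step, func, *args, **kwargs):
    """Capture `func` once, then replay the recorded kernels on later calls."""
    if not torch.cuda.is_available():
        # Assumption for this sketch: without graph support, run eagerly.
        func(*args, **kwargs)
        return
    key = func.__name__
    if key not in _graph_cache:
        # Warm up on a side stream before capturing.
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            func(*args, **kwargs)
        torch.cuda.current_stream().wait_stream(side)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            func(*args, **kwargs)  # recorded, not executed
        _graph_cache[key] = graph
        if replay_first_step:
            graph.replay()
    else:
        _graph_cache[key].replay()


device = "cuda" if torch.cuda.is_available() else "cpu"
acc = torch.zeros(4, device=device)  # fixed-address buffers
inc = torch.ones(4, device=device)


def accumulate(dst, src):
    dst.add_(src)


for _ in range(3):
    run_with_graph(False, accumulate, acc, inc)
```

The function passed in should contain only device operations, and the tensors it touches must keep the same addresses across calls, which is why `acc` and `inc` are allocated once outside the loop.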

You can check out
[4abab21](microsoft@4abab21) and set it to True here to do some benchmarking:
microsoft@4abab21#diff-f8f0b3feb55b0374615405e542c1c3e0f017982b177c46c562bf688532ac935cR42
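Separately from the benchmarking flag above, the constants added in this commit describe a `graph_harvesting` key for `ds_config.json` that defaults to `False`. A minimal sketch of that lookup follows; the `read_graph_harvesting` helper is an illustrative assumption rather than DeepSpeed code:

```python
import json

GRAPH_HARVESTING = "graph_harvesting"
GRAPH_HARVESTING_DEFAULT = False


def read_graph_harvesting(param_dict):
    # Hypothetical helper: fall back to the default when the key is absent.
    return param_dict.get(GRAPH_HARVESTING, GRAPH_HARVESTING_DEFAULT)


ds_config = json.loads('{"bf16": {"enabled": true}, "graph_harvesting": true}')
enabled = read_graph_harvesting(ds_config)
disabled_by_default = read_graph_harvesting({"bf16": {"enabled": True}})
```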

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
2 people authored and amaurya committed Feb 17, 2024
1 parent 48f3e68 commit 1dff0c4
Showing 14 changed files with 189 additions and 47 deletions.
13 changes: 13 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -193,6 +193,19 @@ def communication_backend_name(self):
     def is_triton_supported(self):
         ...
 
+    # Graph operations
+    @abc.abstractmethod
+    def create_graph(self):
+        ...
+
+    @abc.abstractmethod
+    def capture_to_graph(self, graph, pool=None, stream=None):
+        ...
+
+    @abc.abstractmethod
+    def replay_graph(self, graph):
+        ...
+
     # Tensor operations
     @property
     @abc.abstractmethod
12 changes: 11 additions & 1 deletion accelerator/cpu_accelerator.py
@@ -198,8 +198,18 @@ def is_fp16_supported(self):
     def supported_dtypes(self):
         return [torch.float, torch.bfloat16]
 
-    # Tensor operations
+    # Graph operations
+    def create_graph(self):
+        return None
+
+    def capture_to_graph(self, graph, pool=None, stream=None):
+        from deepspeed.runtime.utils import noop_context
+        return noop_context()
+
+    def replay_graph(self, graph):
+        return
 
+    # Tensor operations
     @property
     def BFloat16Tensor(self):
         return torch.BFloat16Tensor
11 changes: 11 additions & 0 deletions accelerator/cuda_accelerator.py
@@ -222,6 +222,17 @@ def is_triton_supported(self):
         else:
             return False
 
+    # Graph operations
+    def create_graph(self):
+        return torch.cuda.CUDAGraph()
+
+    def capture_to_graph(self, graph, pool=None, stream=None):
+        return torch.cuda.graph(graph, pool, stream)
+
+    def replay_graph(self, graph):
+        graph.replay()
+        return
+
     # Tensor operations
 
     @property
11 changes: 11 additions & 0 deletions accelerator/mps_accelerator.py
@@ -166,6 +166,17 @@ def communication_backend_name(self):
     def is_triton_supported(self):
         return False
 
+    # Graph operations
+    def create_graph(self):
+        return None
+
+    def capture_to_graph(self, graph, pool=None, stream=None):
+        from deepspeed.runtime.utils import noop_context
+        return noop_context()
+
+    def replay_graph(self, graph):
+        return
+
     # Tensor operations
     @property
     def BFloat16Tensor(self):
11 changes: 11 additions & 0 deletions accelerator/npu_accelerator.py
@@ -176,6 +176,17 @@ def communication_backend_name(self):
     def is_triton_supported(self):
         return False
 
+    # Graph operations
+    def create_graph(self):
+        return None
+
+    def capture_to_graph(self, graph, pool=None, stream=None):
+        from deepspeed.runtime.utils import noop_context
+        return noop_context()
+
+    def replay_graph(self, graph):
+        return
+
     # Tensor operations
 
     @property
6 changes: 3 additions & 3 deletions deepspeed/inference/engine.py
@@ -531,11 +531,11 @@ def _create_cuda_graph(self, *inputs, **kwargs):
         get_accelerator().current_stream().wait_stream(cuda_stream)
 
         # create cuda_graph and assign static_inputs and static_outputs
-        self._cuda_graphs = torch.cuda.CUDAGraph()
+        self._cuda_graphs = get_accelerator().create_graph()
         self.static_inputs = inputs
         self.static_kwargs = kwargs
 
-        with torch.cuda.graph(self._cuda_graphs):
+        with get_accelerator().capture_to_graph(self._cuda_graphs):
             self.static_output = self.module(*self.static_inputs, **self.static_kwargs)
 
         self.cuda_graph_created = True
@@ -547,7 +547,7 @@ def _graph_replay(self, *inputs, **kwargs):
         for k in kwargs:
             if torch.is_tensor(kwargs[k]):
                 self.static_kwargs[k].copy_(kwargs[k])
-        self._cuda_graphs.replay()
+        get_accelerator().replay_graph(self._cuda_graphs)
         return self.static_output
 
     def model_times(self):
7 changes: 4 additions & 3 deletions deepspeed/model_implementations/diffusers/unet.py
@@ -4,6 +4,7 @@
 # DeepSpeed Team
 
 import torch
+from deepspeed.accelerator import get_accelerator
 from ..features.cuda_graph import CUDAGraph
 
 
@@ -29,7 +30,7 @@ def _graph_replay(self, *inputs, **kwargs):
         for k in kwargs:
             if torch.is_tensor(kwargs[k]):
                 self.static_kwargs[k].copy_(kwargs[k])
-        self._cuda_graphs.replay()
+        get_accelerator().replay_graph(self._cuda_graphs)
         return self.static_output
 
     def forward(self, *inputs, **kwargs):
@@ -53,11 +54,11 @@ def _create_cuda_graph(self, *inputs, **kwargs):
         torch.cuda.current_stream().wait_stream(cuda_stream)
 
         # create cuda_graph and assign static_inputs and static_outputs
-        self._cuda_graphs = torch.cuda.CUDAGraph()
+        self._cuda_graphs = get_accelerator().create_graph()
         self.static_inputs = inputs
         self.static_kwargs = kwargs
 
-        with torch.cuda.graph(self._cuda_graphs):
+        with get_accelerator().capture_to_graph(self._cuda_graphs):
             self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)
 
         self.cuda_graph_created = True
19 changes: 10 additions & 9 deletions deepspeed/model_implementations/diffusers/vae.py
@@ -4,6 +4,7 @@
 # DeepSpeed Team
 
 import torch
+from deepspeed.accelerator import get_accelerator
 from ..features.cuda_graph import CUDAGraph
 
 
@@ -27,7 +28,7 @@ def _graph_replay_decoder(self, *inputs, **kwargs):
         for k in kwargs:
             if torch.is_tensor(kwargs[k]):
                 self.static_decoder_kwargs[k].copy_(kwargs[k])
-        self._decoder_cuda_graph.replay()
+        get_accelerator().replay_graph(self._decoder_cuda_graph)
         return self.static_decoder_output
 
     def _decode(self, x, return_dict=True, generator=None):
@@ -43,11 +44,11 @@ def _create_cuda_graph_decoder(self, *inputs, **kwargs):
         torch.cuda.current_stream().wait_stream(cuda_stream)
 
         # create cuda_graph and assign static_inputs and static_outputs
-        self._decoder_cuda_graph = torch.cuda.CUDAGraph()
+        self._decoder_cuda_graph = get_accelerator().create_graph()
         self.static_decoder_inputs = inputs
         self.static_decoder_kwargs = kwargs
 
-        with torch.cuda.graph(self._decoder_cuda_graph):
+        with get_accelerator().capture_to_graph(self._decoder_cuda_graph):
             self.static_decoder_output = self._decode(*self.static_decoder_inputs, **self.static_decoder_kwargs)
 
         self.decoder_cuda_graph_created = True
@@ -70,7 +71,7 @@ def _graph_replay_encoder(self, *inputs, **kwargs):
         for k in kwargs:
             if torch.is_tensor(kwargs[k]):
                 self.static_encoder_kwargs[k].copy_(kwargs[k])
-        self._encoder_cuda_graph.replay()
+        get_accelerator().replay_graph(self._encoder_cuda_graph)
         return self.static_encoder_output
 
     def _encode(self, x, return_dict=True):
@@ -86,11 +87,11 @@ def _create_cuda_graph_encoder(self, *inputs, **kwargs):
         torch.cuda.current_stream().wait_stream(cuda_stream)
 
         # create cuda_graph and assign static_inputs and static_outputs
-        self._encoder_cuda_graph = torch.cuda.CUDAGraph()
+        self._encoder_cuda_graph = get_accelerator().create_graph()
         self.static_encoder_inputs = inputs
         self.static_encoder_kwargs = kwargs
 
-        with torch.cuda.graph(self._encoder_cuda_graph):
+        with get_accelerator().capture_to_graph(self._encoder_cuda_graph):
             self.static_encoder_output = self._encode(*self.static_encoder_inputs, **self.static_encoder_kwargs)
 
         self.encoder_cuda_graph_created = True
@@ -113,7 +114,7 @@ def _graph_replay(self, *inputs, **kwargs):
         for k in kwargs:
             if torch.is_tensor(kwargs[k]):
                 self.static_kwargs[k].copy_(kwargs[k])
-        self._all_cuda_graph.replay()
+        get_accelerator().replay_graph(self._all_cuda_graph)
         return self.static_output
 
     def forward(self, *inputs, **kwargs):
@@ -137,11 +138,11 @@ def _create_cuda_graph(self, *inputs, **kwargs):
         torch.cuda.current_stream().wait_stream(cuda_stream)
 
         # create cuda_graph and assign static_inputs and static_outputs
-        self._all_cuda_graph = torch.cuda.CUDAGraph()
+        self._all_cuda_graph = get_accelerator().create_graph()
         self.static_inputs = inputs
         self.static_kwargs = kwargs
 
-        with torch.cuda.graph(self._all_cuda_graph):
+        with get_accelerator().capture_to_graph(self._all_cuda_graph):
             self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)
 
         self.all_cuda_graph_created = True
6 changes: 3 additions & 3 deletions deepspeed/model_implementations/transformers/clip_encoder.py
@@ -38,7 +38,7 @@ def _graph_replay(self, *inputs, **kwargs):
         for k in kwargs:
             if torch.is_tensor(kwargs[k]):
                 self.static_kwargs[self.iter][k].copy_(kwargs[k])
-        self._cuda_graphs[self.iter].replay()
+        get_accelerator().replay_graph(self._cuda_graphs[self.iter])
         return self.static_output[self.iter]
 
     def forward(self, *inputs, **kwargs):
@@ -63,11 +63,11 @@ def _create_cuda_graph(self, *inputs, **kwargs):
         torch.cuda.current_stream().wait_stream(cuda_stream)
 
         # create cuda_graph and assign static_inputs and static_outputs
-        self._cuda_graphs[self.iter] = torch.cuda.CUDAGraph()
+        self._cuda_graphs[self.iter] = get_accelerator().create_graph()
         self.static_inputs[self.iter] = inputs
         self.static_kwargs[self.iter] = kwargs
 
-        with torch.cuda.graph(self._cuda_graphs[self.iter]):
+        with get_accelerator().capture_to_graph(self._cuda_graphs[self.iter]):
             self.static_output[self.iter] = self._forward(*self.static_inputs[self.iter],
                                                           **self.static_kwargs[self.iter])
 
49 changes: 32 additions & 17 deletions deepspeed/runtime/bf16_optimizer.py
@@ -16,7 +16,7 @@
 from deepspeed.git_version_info import version
 from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
                                      align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
-                                     is_model_parallel_parameter, see_memory_usage)
+                                     is_model_parallel_parameter, see_memory_usage, graph_process)
 
 from deepspeed.utils import link_hp_params, fragment_address
 from deepspeed.checkpoint import enable_universal_checkpoint
@@ -38,7 +38,8 @@ def __init__(self,
                  allgather_bucket_size=5000000000,
                  dp_process_group=None,
                  timers=None,
-                 grad_acc_dtype=None):
+                 grad_acc_dtype=None,
+                 graph_harvesting=False):
         super().__init__()
         see_memory_usage('begin bf16_optimizer', force=True)
         self.timers = timers
@@ -81,7 +82,7 @@ def __init__(self,
         self.fp32_groups_has_gradients = []
 
         self.group_paddings = []
-
+        self.graph_harvesting = graph_harvesting
         if self.using_real_optimizer:
             self._setup_for_real_optimizer()
 
@@ -248,15 +249,17 @@ def step(self, closure=None):
 
         all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
                                                      mpu=self.mpu,
-                                                     norm_type=self.norm_type)
+                                                     norm_type=self.norm_type,
+                                                     use_graph=self.graph_harvesting)
         self._global_grad_norm = all_groups_norm
 
         assert all_groups_norm > 0.
         if self.clip_grad > 0.:
             clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True),
                                         max_norm=self.clip_grad,
                                         global_norm=all_groups_norm,
-                                        mpu=self.mpu)
+                                        mpu=self.mpu,
+                                        use_graph=self.graph_harvesting)
 
         self.optimizer.step()
 
@@ -281,23 +284,33 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg
 
     @torch.no_grad()
     def update_hp_grads(self, clear_lp_grads=False):
+
+        def _update_hp_grads_func(clear_lp_grads=False):
+            for i, group in enumerate(self.bf16_groups):
+                for j, lp in enumerate(group):
+                    if lp.grad is None:
+                        continue
+                    hp_grad = self.fp32_groups_gradients[i][j]
+                    assert hp_grad is not None, \
+                        f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'
+                    hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
+                    lp._hp_grad = hp_grad
+                    self.fp32_groups_has_gradients[i][j] = True
+                    # clear gradients
+                    if clear_lp_grads:
+                        lp.grad.zero_()
+
+        if self.graph_harvesting:
+            graph_process(False, _update_hp_grads_func, clear_lp_grads)
+        else:
+            _update_hp_grads_func(clear_lp_grads)
+        #cpu op
         for i, group in enumerate(self.bf16_groups):
             for j, lp in enumerate(group):
                 if lp.grad is None:
                     continue
-
-                hp_grad = self.fp32_groups_gradients[i][j]
-                assert hp_grad is not None, \
-                    f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'
-
-                hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
-                lp._hp_grad = hp_grad
                 self.fp32_groups_has_gradients[i][j] = True
-
-                # clear gradients
-                if clear_lp_grads:
-                    lp.grad = None
 
     @torch.no_grad()
     def get_grads_for_reduction(self):
         return self.fp32_groups_gradients_flat
@@ -348,7 +361,9 @@ def clear_hp_grads(self):
     def clear_lp_grads(self):
         for group in self.bf16_groups:
             for param in group:
-                param.grad = None
+                if param.grad is not None:
+                    # Use zero_() to keep the memory address fixed for graph replay
+                    param.grad.zero_()
 
     def state_dict(self):
         state_dict = {}
5 changes: 5 additions & 0 deletions deepspeed/runtime/config.py
@@ -279,6 +279,10 @@ def get_gradient_clipping(param_dict):
     return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT)
 
 
+def get_graph_harvesting(param_dict):
+    return get_scalar_param(param_dict, GRAPH_HARVESTING, GRAPH_HARVESTING_DEFAULT)
+
+
 def get_sparse_attention(param_dict):
     if SPARSE_ATTENTION in param_dict.keys():
         sparsity = param_dict[SPARSE_ATTENTION]
@@ -823,6 +827,7 @@ def _initialize_params(self, param_dict):
         self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)
 
         self.compression_config = get_compression_config(param_dict)
+        self.graph_harvesting = get_graph_harvesting(param_dict)
 
         self.optimizer_name = get_optimizer_name(param_dict)
         if (self.optimizer_name is not None and self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS):
12 changes: 12 additions & 0 deletions deepspeed/runtime/constants.py
@@ -210,6 +210,18 @@
 GRADIENT_CLIPPING = 'gradient_clipping'
 GRADIENT_CLIPPING_DEFAULT = 0.
 
+#########################################
+# Capture graph for short kernels sequences
+#########################################
+# Graph harvesting. By default, this feature is not enabled.
+# Users can configure in ds_config.json as below example:
+GRAPH_HARVESTING_FORMAT = '''
+Graph harvesting should be enabled as:
+"graph_harvesting": true
+'''
+GRAPH_HARVESTING = 'graph_harvesting'
+GRAPH_HARVESTING_DEFAULT = False
+
 #########################################
 # Communication data type
 #########################################