From 1dff0c40ceb4df7de9f44703661aba087fb64830 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 21 Dec 2023 04:51:36 +0800 Subject: [PATCH] Capture short kernel sequences to graph (#4318) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Motivation:** 1. This is a series of cases where short kernel sequences are launched and executed serially(no dynamic shape), with the launch overhead being much higher than the execution overhead. We can use a graph to solve this problem. Compared to ```multi-tensor-apply```, using graph is more concise and only requires PyTorch as a dependency. 2. Some device software stacks also support lazy-mode PyTorch, enabling full utilization of the compiler to perform graph optimization. However, in lazy mode, operation accumulation time (host time) could become significantly higher compared to device time in such scenario, and devices are usually not well utilized. By using the same API(after adding to accelerator cc @delock ) with cuda graph, this issue could also be resolved. **Change:** We modified three functions, ```update_hp_grads```. Here, we executed the operations for the CPU and GPU separately because the graph is unable to record the execution of CPU operations. Additionally, the data input required by the graph must not have its address modified, or the address modification must be captured by the capture operation(In this case, set ```replay_first_step``` to ```True```). Therefore, we changed ```grad=None``` to ```grad.zero_()```. Similarly, we have also placed some inputs that require fixed addresses in the ```graph_cache``` For ```clip_tensors_by_global_norm```, ```clip_coef``` is a scalar with a non-fixed value, so it needs to be moved to the GPU when using a graph. For ```total_norm = sum ([t. data. float (). norm (norm_type). item () * * norm_type for t in input_tensors])```, ```item () ```, synchronous operation is also not supported by graph. We directly put the ```sum``` and ```* * norm_type``` on the GPU to execute the computation. Other similar scenarios can also use this ```graph_process()```, or a slightly modified version of ```graph_process()``` you can checkout [4abab21](https://github.com/microsoft/DeepSpeed/pull/4318/commits/4abab212c8f5aef1eec4f8abe10b4262bb5a5c8a) and set it to True here to do some benchmarking. https://github.com/microsoft/DeepSpeed/pull/4318/commits/4abab212c8f5aef1eec4f8abe10b4262bb5a5c8a#diff-f8f0b3feb55b0374615405e542c1c3e0f017982b177c46c562bf688532ac935cR42 --------- Co-authored-by: Olatunji Ruwase --- accelerator/abstract_accelerator.py | 13 ++++ accelerator/cpu_accelerator.py | 12 +++- accelerator/cuda_accelerator.py | 11 +++ accelerator/mps_accelerator.py | 11 +++ accelerator/npu_accelerator.py | 11 +++ deepspeed/inference/engine.py | 6 +- .../model_implementations/diffusers/unet.py | 7 +- .../model_implementations/diffusers/vae.py | 19 +++--- .../transformers/clip_encoder.py | 6 +- deepspeed/runtime/bf16_optimizer.py | 49 ++++++++----- deepspeed/runtime/config.py | 5 ++ deepspeed/runtime/constants.py | 12 ++++ deepspeed/runtime/engine.py | 6 +- deepspeed/runtime/utils.py | 68 ++++++++++++++++--- 14 files changed, 189 insertions(+), 47 deletions(-) diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index 7c1f463009b2..8ba652a8de92 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -193,6 +193,19 @@ def communication_backend_name(self): def is_triton_supported(self): ... + # Graph operations + @abc.abstractmethod + def create_graph(self): + ... + + @abc.abstractmethod + def capture_to_graph(self, graph, pool=None, stream=None): + ... + + @abc.abstractmethod + def replay_graph(self, graph): + ... + # Tensor operations @property @abc.abstractmethod diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index 05c4f6c650f0..a02777f5223b 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -198,8 +198,18 @@ def is_fp16_supported(self): def supported_dtypes(self): return [torch.float, torch.bfloat16] - # Tensor operations + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + # Tensor operations @property def BFloat16Tensor(self): return torch.BFloat16Tensor diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 4b94ddb6865c..2d74daecf3df 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -222,6 +222,17 @@ def is_triton_supported(self): else: return False + # Graph operations + def create_graph(self): + return torch.cuda.CUDAGraph() + + def capture_to_graph(self, graph, pool=None, stream=None): + return torch.cuda.graph(graph, pool, stream) + + def replay_graph(self, graph): + graph.replay() + return + # Tensor operations @property diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index f0d4cac721b5..f6303cf9890f 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -166,6 +166,17 @@ def communication_backend_name(self): def is_triton_supported(self): return False + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + # Tensor operations @property def BFloat16Tensor(self): diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 8bfd59cd2ad6..4e20445d9d32 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -176,6 +176,17 @@ def communication_backend_name(self): def is_triton_supported(self): return False + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + # Tensor operations @property diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index ccecc8376ad6..5cdd99ff0b90 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -531,11 +531,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): get_accelerator().current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs = torch.cuda.CUDAGraph() + self._cuda_graphs = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._cuda_graphs): + with get_accelerator().capture_to_graph(self._cuda_graphs): self.static_output = self.module(*self.static_inputs, **self.static_kwargs) self.cuda_graph_created = True @@ -547,7 +547,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._cuda_graphs.replay() + get_accelerator().replay_graph(self._cuda_graphs) return self.static_output def model_times(self): diff --git a/deepspeed/model_implementations/diffusers/unet.py b/deepspeed/model_implementations/diffusers/unet.py index 7da571975958..8d5ddd95437a 100644 --- a/deepspeed/model_implementations/diffusers/unet.py +++ b/deepspeed/model_implementations/diffusers/unet.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.accelerator import get_accelerator from ..features.cuda_graph import CUDAGraph @@ -29,7 +30,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._cuda_graphs.replay() + get_accelerator().replay_graph(self._cuda_graphs) return self.static_output def forward(self, *inputs, **kwargs): @@ -53,11 +54,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs = torch.cuda.CUDAGraph() + self._cuda_graphs = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._cuda_graphs): + with get_accelerator().capture_to_graph(self._cuda_graphs): self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) self.cuda_graph_created = True diff --git a/deepspeed/model_implementations/diffusers/vae.py b/deepspeed/model_implementations/diffusers/vae.py index 05084f1b985a..ce50ade647a8 100644 --- a/deepspeed/model_implementations/diffusers/vae.py +++ b/deepspeed/model_implementations/diffusers/vae.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.accelerator import get_accelerator from ..features.cuda_graph import CUDAGraph @@ -27,7 +28,7 @@ def _graph_replay_decoder(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_decoder_kwargs[k].copy_(kwargs[k]) - self._decoder_cuda_graph.replay() + get_accelerator().replay_graph(self._decoder_cuda_graph) return self.static_decoder_output def _decode(self, x, return_dict=True, generator=None): @@ -43,11 +44,11 @@ def _create_cuda_graph_decoder(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._decoder_cuda_graph = torch.cuda.CUDAGraph() + self._decoder_cuda_graph = get_accelerator().create_graph() self.static_decoder_inputs = inputs self.static_decoder_kwargs = kwargs - with torch.cuda.graph(self._decoder_cuda_graph): + with get_accelerator().capture_to_graph(self._decoder_cuda_graph): self.static_decoder_output = self._decode(*self.static_decoder_inputs, **self.static_decoder_kwargs) self.decoder_cuda_graph_created = True @@ -70,7 +71,7 @@ def _graph_replay_encoder(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_encoder_kwargs[k].copy_(kwargs[k]) - self._encoder_cuda_graph.replay() + get_accelerator().replay_graph(self._encoder_cuda_graph) return self.static_encoder_output def _encode(self, x, return_dict=True): @@ -86,11 +87,11 @@ def _create_cuda_graph_encoder(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._encoder_cuda_graph = torch.cuda.CUDAGraph() + self._encoder_cuda_graph = get_accelerator().create_graph() self.static_encoder_inputs = inputs self.static_encoder_kwargs = kwargs - with torch.cuda.graph(self._encoder_cuda_graph): + with get_accelerator().capture_to_graph(self._encoder_cuda_graph): self.static_encoder_output = self._encode(*self.static_encoder_inputs, **self.static_encoder_kwargs) self.encoder_cuda_graph_created = True @@ -113,7 +114,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._all_cuda_graph.replay() + get_accelerator().replay_graph(self._all_cuda_graph) return self.static_output def forward(self, *inputs, **kwargs): @@ -137,11 +138,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._all_cuda_graph = torch.cuda.CUDAGraph() + self._all_cuda_graph = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._all_cuda_graph): + with get_accelerator().capture_to_graph(self._all_cuda_graph): self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) self.all_cuda_graph_created = True diff --git a/deepspeed/model_implementations/transformers/clip_encoder.py b/deepspeed/model_implementations/transformers/clip_encoder.py index 8d9291896986..848a5b48dcf1 100644 --- a/deepspeed/model_implementations/transformers/clip_encoder.py +++ b/deepspeed/model_implementations/transformers/clip_encoder.py @@ -38,7 +38,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[self.iter][k].copy_(kwargs[k]) - self._cuda_graphs[self.iter].replay() + get_accelerator().replay_graph(self._cuda_graphs[self.iter]) return self.static_output[self.iter] def forward(self, *inputs, **kwargs): @@ -63,11 +63,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs[self.iter] = torch.cuda.CUDAGraph() + self._cuda_graphs[self.iter] = get_accelerator().create_graph() self.static_inputs[self.iter] = inputs self.static_kwargs[self.iter] = kwargs - with torch.cuda.graph(self._cuda_graphs[self.iter]): + with get_accelerator().capture_to_graph(self._cuda_graphs[self.iter]): self.static_output[self.iter] = self._forward(*self.static_inputs[self.iter], **self.static_kwargs[self.iter]) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 888505279290..0aefd1946c36 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -16,7 +16,7 @@ from deepspeed.git_version_info import version from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim, align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank, - is_model_parallel_parameter, see_memory_usage) + is_model_parallel_parameter, see_memory_usage, graph_process) from deepspeed.utils import link_hp_params, fragment_address from deepspeed.checkpoint import enable_universal_checkpoint @@ -38,7 +38,8 @@ def __init__(self, allgather_bucket_size=5000000000, dp_process_group=None, timers=None, - grad_acc_dtype=None): + grad_acc_dtype=None, + graph_harvesting=False): super().__init__() see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers @@ -81,7 +82,7 @@ def __init__(self, self.fp32_groups_has_gradients = [] self.group_paddings = [] - + self.graph_harvesting = graph_harvesting if self.using_real_optimizer: self._setup_for_real_optimizer() @@ -248,7 +249,8 @@ def step(self, closure=None): all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(), mpu=self.mpu, - norm_type=self.norm_type) + norm_type=self.norm_type, + use_graph=self.graph_harvesting) self._global_grad_norm = all_groups_norm assert all_groups_norm > 0. @@ -256,7 +258,8 @@ def step(self, closure=None): clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True), max_norm=self.clip_grad, global_norm=all_groups_norm, - mpu=self.mpu) + mpu=self.mpu, + use_graph=self.graph_harvesting) self.optimizer.step() @@ -281,23 +284,33 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg @torch.no_grad() def update_hp_grads(self, clear_lp_grads=False): + + def _update_hp_grads_func(clear_lp_grads=False): + for i, group in enumerate(self.bf16_groups): + for j, lp in enumerate(group): + if lp.grad is None: + continue + hp_grad = self.fp32_groups_gradients[i][j] + assert hp_grad is not None, \ + f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' + hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) + lp._hp_grad = hp_grad + self.fp32_groups_has_gradients[i][j] = True + # clear gradients + if clear_lp_grads: + lp.grad._zero() + + if self.graph_harvesting: + graph_process(False, _update_hp_grads_func, clear_lp_grads) + else: + _update_hp_grads_func(clear_lp_grads) + #cpu op for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): if lp.grad is None: continue - - hp_grad = self.fp32_groups_gradients[i][j] - assert hp_grad is not None, \ - f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' - - hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) - lp._hp_grad = hp_grad self.fp32_groups_has_gradients[i][j] = True - # clear gradients - if clear_lp_grads: - lp.grad = None - @torch.no_grad() def get_grads_for_reduction(self): return self.fp32_groups_gradients_flat @@ -348,7 +361,9 @@ def clear_hp_grads(self): def clear_lp_grads(self): for group in self.bf16_groups: for param in group: - param.grad = None + if param.grad is not None: + # Using zero_() fixed memory address for graph replay + param.grad.zero_() def state_dict(self): state_dict = {} diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 737ae5795c67..8e3d45ee99b2 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -279,6 +279,10 @@ def get_gradient_clipping(param_dict): return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT) +def get_graph_harvesting(param_dict): + return get_scalar_param(param_dict, GRAPH_HARVESTING, GRAPH_HARVESTING_DEFAULT) + + def get_sparse_attention(param_dict): if SPARSE_ATTENTION in param_dict.keys(): sparsity = param_dict[SPARSE_ATTENTION] @@ -823,6 +827,7 @@ def _initialize_params(self, param_dict): self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict) self.compression_config = get_compression_config(param_dict) + self.graph_harvesting = get_graph_harvesting(param_dict) self.optimizer_name = get_optimizer_name(param_dict) if (self.optimizer_name is not None and self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS): diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 96f2a38bd05c..82d8a0557a41 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -210,6 +210,18 @@ GRADIENT_CLIPPING = 'gradient_clipping' GRADIENT_CLIPPING_DEFAULT = 0. +######################################### +# Capture graph for short kernels sequences +######################################### +# Graph harvesting. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +GRAPH_HARVESTING_FORMAT = ''' +Graph harvesting should be enabled as: +"graph_harvesting": true +''' +GRAPH_HARVESTING = 'graph_harvesting' +GRAPH_HARVESTING_DEFAULT = False + ######################################### # Communication data type ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 26d37fdc79ac..96b9ffb19fa8 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -771,6 +771,9 @@ def zero_legacy_stage1(self): def zero_ignore_unused_parameters(self): return self._config.zero_config.ignore_unused_parameters + def graph_harvesting(self): + return self._config.graph_harvesting + def fp16_enabled(self): return self._config.fp16_enabled @@ -1466,7 +1469,8 @@ def _configure_bf16_optimizer(self, optimizer): allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.seq_data_parallel_group, timers=timers, - grad_acc_dtype=self.get_data_types()[1]) + grad_acc_dtype=self.get_data_types()[1], + graph_harvesting=self.graph_harvesting()) return optimizer diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index bc7a782e590c..82f200fccf9f 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -47,6 +47,27 @@ def __init__(self, params): self.param_groups.append({'params': params}) +graph_cache = {} + + +def graph_process(replay_first_step, func, *args, **kwargs): + # `func` should only contain operations on the GPU + # Please ensure that the memory address of the data required by 'func' remains constant + if func.__name__ not in graph_cache: + cuda_stream = get_accelerator().Stream() + cuda_stream.wait_stream(get_accelerator().current_stream()) + with get_accelerator().stream(cuda_stream): + func(*args, **kwargs) + get_accelerator().current_stream().wait_stream(cuda_stream) + graph_cache[func.__name__] = get_accelerator().create_graph() + with get_accelerator().capture_to_graph(graph_cache[func.__name__]): + func(*args, **kwargs) + if replay_first_step: + get_accelerator().replay_graph(graph_cache[func.__name__]) + else: + get_accelerator().replay_graph(graph_cache[func.__name__]) + + def noop_decorator(func): return func @@ -831,7 +852,7 @@ def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, ep return global_grad_norm -def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): +def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False): """Get norm of an iterable of tensors. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and @@ -845,7 +866,6 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): Returns: Total norm of the tensors (viewed as a single vector). """ - assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}' assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors' @@ -857,8 +877,24 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item() else: - total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors]) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + if use_graph: + if 'norm_tensors_compute_buffer' not in graph_cache: + graph_cache['norm_tensors_compute_buffer'] = [t.data.float().norm(norm_type) for t in input_tensors] + compute_buffer = graph_cache['norm_tensors_compute_buffer'] + + def _norm_tensors(tensor_list, _compute_buffer, _norm_type): + for i, t in enumerate(tensor_list): + _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type) + if i != 0: + _compute_buffer[0].data.add_(_compute_buffer[i].data) + + graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type) + + total_norm = compute_buffer[0] + else: + total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors]) + + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach() if mpu is not None: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item()**(1. / norm_type) @@ -869,7 +905,7 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): return total_norm -def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6): +def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6, use_graph=False): """Clip list of tensors by global norm. Args: input_tensors: List of tensors to be clipped @@ -880,14 +916,26 @@ def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, m float: the global norm """ if global_norm is None: - global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu) - + global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu, use_graph=use_graph) clip_coef = max_norm / (global_norm + eps) - if clip_coef < 1: - for t in input_tensors: - t.detach().mul_(clip_coef) + if use_graph: + def clip_tensors(_tensor_list, _clip_coef_tensor): + for t in _tensor_list: + t.detach().mul_(_clip_coef_tensor) + + if 'clip_coef_tensor' not in graph_cache: + # Alloc memory + graph_cache['clip_coef_tensor'] = torch.tensor(clip_coef, + dtype=torch.float32).to(get_accelerator().device_name()) + clip_coef_tensor = graph_cache['clip_coef_tensor'] + clip_coef_tensor.copy_(torch.tensor(clip_coef, dtype=torch.float32)) + graph_process(False, clip_tensors, input_tensors, clip_coef_tensor) + + else: + for t in input_tensors: + t.detach().mul_(clip_coef) return global_norm