From f515104e95f3b337f0cb59bb9e87439b44077799 Mon Sep 17 00:00:00 2001
From: Hyeonseung Lee
Date: Fri, 22 Nov 2024 02:37:29 +0900
Subject: [PATCH] Removes unnecessary cloning (#6761)

In `clone_tensors_for_torch_save()`, when `item.device` differs from the
`device` argument, the explicit `tensor.clone()` is not actually required,
because `to()` already copies the original tensor when moving it to a
different device.

Additionally, I observed memory bloat under the following conditions:

* Training a Whisper model with the `transformers` framework under `ZeRO-0`
  and `ZeRO-1` configurations.
* The bloat appeared every time the model state_dict was cloned via
  `clone_tensors_for_torch_save()`.

After removing the unnecessary `clone()`, the problem appears to be solved.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase
---
 deepspeed/checkpoint/utils.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/deepspeed/checkpoint/utils.py b/deepspeed/checkpoint/utils.py
index c305e8884e83..5964da00728e 100644
--- a/deepspeed/checkpoint/utils.py
+++ b/deepspeed/checkpoint/utils.py
@@ -51,7 +51,12 @@ def clone_tensors_for_torch_save(item, device=torch.device('cpu')):
         - copy of ``item`` with cloned tensors on target device
     """
     if torch.is_tensor(item):
-        return item.detach().clone().to(device)
+        if type(device) is str:
+            device = torch.device(device)
+        if device == item.device:
+            return item.detach().clone()
+        else:
+            return item.detach().to(device)
     elif isinstance(item, list):
         return [clone_tensors_for_torch_save(v, device) for v in item]
     elif isinstance(item, tuple):
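
For context, the optimization rests on documented PyTorch behavior: `Tensor.to()` returns the original tensor unchanged when no device or dtype conversion is needed, and returns a fresh tensor otherwise. Below is a minimal sketch (not part of the patch) illustrating this; it assumes PyTorch is installed, and the CUDA branch runs only when a GPU is present:

```python
import torch

t = torch.ones(2, 2)  # allocated on CPU

# Same device: .to() performs no conversion and returns the *same* tensor
# object, so the same-device branch still needs an explicit clone() to get
# an independent copy for torch.save().
assert t.to(torch.device('cpu')) is t

# Different device: .to() already materializes a new tensor on the target
# device, so chaining clone().to(device) would allocate a redundant
# intermediate copy (the bloat this patch removes).
if torch.cuda.is_available():
    moved = t.to(torch.device('cuda'))
    assert moved is not t
```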