From e4c22121aa358294e4f7a6601df085ea00e08bec Mon Sep 17 00:00:00 2001 From: Xu Xinran Date: Mon, 29 May 2023 08:24:24 +0000 Subject: [PATCH] fix-ipg-buffer-data-race --- deepspeed/runtime/zero/stage_1_and_2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 8f15b87af0e2..8345a0bc1c66 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -375,7 +375,7 @@ def __init__(self, self.reduce_bucket_size = int(reduce_bucket_size) self.allgather_bucket_size = int(allgather_bucket_size) - self.reduction_event = get_accelerator().Event(enable_timing=False, blocking=False) + self.reduction_events = [get_accelerator().Event(enable_timing=False, blocking=False) for _ in range(2)] self.reduction_stream = get_accelerator().Stream() self.cpu_computation_stream = get_accelerator().Stream() self.copy_grad_stream = get_accelerator().Stream() @@ -835,6 +835,8 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel()) self.reduce_ipg_grads() if self.contiguous_gradients and self.overlap_comm: + # Wait until buffer reader finish + self.reduction_events[1 - self.ipg_index].wait(get_accelerator().current_stream()) # Swap ipg_index between 0 and 1 self.ipg_index = 1 - self.ipg_index self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel()) @@ -1246,6 +1248,7 @@ def reduce_ipg_grads(self): else: # zero stage 1 - partition only optimizer state if self.contiguous_gradients and self.is_param_in_current_partition[param_id]: self.copy_grads_in_partition(param) + self.reduction_events[self.ipg_index].record(stream) self.grads_in_ipg_bucket = [] self.params_in_ipg_bucket = []