Skip to content

Commit

Permalink
Merge pull request bytedance#97 from bytedance/add_barrier_to_micro_perf
Browse files Browse the repository at this point in the history
add barrier for ccl ops in micro_perf.
  • Loading branch information
suisiyuan authored Aug 23, 2024
2 parents 572c388 + 2a49ac9 commit b5ac619
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
5 changes: 4 additions & 1 deletion byte_micro_perf/backends/AMD/backend_amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,7 @@ def setup_2d_group(self):
torch.distributed.barrier()

def destroy_process_group(self):
dist.destroy_process_group()
dist.destroy_process_group()

def barier(self):
dist.barrier(self.group)
5 changes: 4 additions & 1 deletion byte_micro_perf/backends/GPU/backend_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,7 @@ def setup_2d_group(self):
torch.distributed.barrier()

def destroy_process_group(self):
dist.destroy_process_group()
dist.destroy_process_group()

def barier(self):
dist.barrier(self.group)
14 changes: 13 additions & 1 deletion byte_micro_perf/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def setup_2d_group(self):
def destroy_process_group(self):
pass

@abstractmethod
def barier(self):
pass


# communication ops
Expand Down Expand Up @@ -229,6 +232,11 @@ def perf(self, input_shapes: List[List[int]], dtype):
for _ in range(num_warm_up):
self._run_operation(self.op, tensor_list[0])


# ccl ops need barrier
if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]:
self.barier()

# test perf
num_test_perf = 5
self.device_synchronize()
Expand All @@ -241,7 +249,6 @@ def perf(self, input_shapes: List[List[int]], dtype):
self.device_synchronize()
end_time = time.perf_counter_ns()


prefer_iterations = self.iterations
max_perf_seconds = 10.0
op_duration = (end_time - start_time) / num_test_perf / 1e9
Expand All @@ -250,6 +257,11 @@ def perf(self, input_shapes: List[List[int]], dtype):
else:
prefer_iterations = min(max(int(max_perf_seconds // op_duration), 10), self.iterations)


# ccl ops need barrier
if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]:
self.barier()

# perf
self.device_synchronize()
start_time = time.perf_counter_ns()
Expand Down

0 comments on commit b5ac619

Please sign in to comment.