diff --git a/msccl/language/collectives.py b/msccl/language/collectives.py index c8434da..2e4fa73 100755 --- a/msccl/language/collectives.py +++ b/msccl/language/collectives.py @@ -73,9 +73,12 @@ def check(self, prog): class AllGather(Collective): - def __init__(self, num_ranks, chunk_factor, inplace): + def __init__(self, num_ranks, chunk_factor, inplace, create_all_chunks=False): Collective.__init__(self, num_ranks, chunk_factor, inplace) self.name = "allgather" + # This flag is a temporary solution, which initializes all the chunks only for the input buffer + # In the future we need to remove this flag and always initialize all the chunks + self.create_all_chunks = create_all_chunks # Initializes input buffer for an allgather def init_buffers(self): @@ -84,8 +87,13 @@ def init_buffers(self): # Inplace AllGather only uses the output buffer for r in range(self.num_ranks): output_buffer = [None] * (self.num_ranks * self.chunk_factor) - for ch in range(self.chunk_factor): - output_buffer[r * self.chunk_factor + ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch) + if not self.create_all_chunks: + for ch in range(self.chunk_factor): + output_buffer[r * self.chunk_factor + ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch) + else: + for rank in range(self.num_ranks): + for ch in range(self.chunk_factor): + output_buffer[rank * self.chunk_factor + ch] = Chunk(rank, ch, -1, rank * self.chunk_factor + ch) buffers = { Buffer.input: output_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor], Buffer.output: output_buffer,