diff --git a/python/dgl/graphbolt/item_sampler.py b/python/dgl/graphbolt/item_sampler.py
index e2e646f9bb96..57ebd684cf2b 100644
--- a/python/dgl/graphbolt/item_sampler.py
+++ b/python/dgl/graphbolt/item_sampler.py
@@ -100,7 +100,21 @@ class ItemShufflerAndBatcher:
     drop_last : bool
         Option to drop the last batch if it's not full.
     buffer_size : int
-        The size of the buffer to store items sliced from the :class:`ItemSet`.
+        The size of the buffer to store items sliced from the :class:`ItemSet`
+        or :class:`ItemSetDict`.
+    distributed : bool
+        Option to enable processing by :class:`DistributedItemSampler`.
+    drop_uneven_inputs : bool
+        Option to ensure the number of batches is the same for each replica.
+        Applies only when `distributed` is True.
+    world_size : int
+        The number of model replicas that will be created during Distributed
+        Data Parallel (DDP) training. It should match the actual world size;
+        otherwise, errors may occur. Applies only when `distributed` is
+        True.
+    rank : int
+        The rank of the current replica. Applies only when `distributed` is
+        True.
     """

     def __init__(
@@ -109,7 +123,7 @@ def __init__(
         shuffle: bool,
         batch_size: int,
         drop_last: bool,
-        buffer_size: Optional[int] = 10 * 1000,
+        buffer_size: int,
         distributed: Optional[bool] = False,
         drop_uneven_inputs: Optional[bool] = False,
         world_size: Optional[int] = 1,
@@ -119,7 +133,7 @@ def __init__(
         self._shuffle = shuffle
         self._batch_size = batch_size
         self._drop_last = drop_last
-        self._buffer_size = max(buffer_size, 20 * batch_size)
+        self._buffer_size = buffer_size
         # Round up the buffer size to the nearest multiple of batch size.
         self._buffer_size = (
             (self._buffer_size + batch_size - 1) // batch_size * batch_size
@@ -240,6 +254,24 @@ class ItemSampler(IterDataPipe):
         Option to drop the last batch if it's not full.
     shuffle : bool
         Option to shuffle before sample.
+    use_indexing : bool
+        Option to use indexing to slice items from the item set. This is an
+        optimization that avoids time-consuming iteration over the item set.
+        If the item set does not support indexing, this option is disabled
+        automatically. If the item set supports indexing but the user wants
+        to disable it, this option can be set to False. By default, it is
+        set to True.
+    buffer_size : int
+        The size of the buffer to store items sliced from the :class:`ItemSet`
+        or :class:`ItemSetDict`. By default, it is set to -1, which means the
+        buffer size equals the total number of items in the item set if
+        indexing is supported; otherwise, it is set to 10 * batch size. If
+        the item set is too large, it is recommended to set a smaller buffer
+        size to avoid out-of-memory errors. As items are only shuffled within
+        each buffer, a smaller buffer size reduces randomness, which can in
+        turn degrade training performance such as convergence speed and
+        accuracy. Therefore, it is recommended to set as large a buffer size
+        as possible.

     Examples
     --------
@@ -429,23 +461,35 @@ def __init__(
         # [TODO][Rui] For now, it's a temporary knob to disable indexing. In
         # the future, we will enable indexing for all the item sets.
         use_indexing: Optional[bool] = True,
+        buffer_size: Optional[int] = -1,
     ) -> None:
         super().__init__()
         self._names = item_set.names
         # Check if the item set supports indexing.
+        indexable = True
         try:
             item_set[0]
         except TypeError:
-            use_indexing = False
-        self._use_indexing = use_indexing
+            indexable = False
+        self._use_indexing = use_indexing and indexable
         self._item_set = (
             item_set if self._use_indexing else IterableWrapper(item_set)
         )
+        if buffer_size == -1:
+            if indexable:
+                # Set the buffer size to the total number of items in the
+                # item set if indexing is supported and the buffer size is
+                # not specified.
+                buffer_size = len(self._item_set)
+            else:
+                # Set the buffer size to 10 * batch size if indexing is not
+                # supported and the buffer size is not specified.
+                buffer_size = 10 * batch_size
+        self._buffer_size = buffer_size
         self._batch_size = batch_size
         self._minibatcher = minibatcher
         self._drop_last = drop_last
         self._shuffle = shuffle
-        self._use_indexing = use_indexing
         self._distributed = False
         self._drop_uneven_inputs = False
         self._world_size = None
@@ -454,11 +498,7 @@ def __init__(
     def _organize_items(self, data_pipe) -> None:
         # Shuffle before batch.
         if self._shuffle:
-            # `torchdata.datapipes.iter.Shuffler` works with stream too.
-            # To ensure randomness, make sure the buffer size is at least 10
-            # times the batch size.
-            buffer_size = max(10000, 10 * self._batch_size)
-            data_pipe = data_pipe.shuffle(buffer_size=buffer_size)
+            data_pipe = data_pipe.shuffle(buffer_size=self._buffer_size)

         # Batch.
         data_pipe = data_pipe.batch(
@@ -495,6 +535,7 @@ def __iter__(self) -> Iterator:
             self._shuffle,
             self._batch_size,
             self._drop_last,
+            self._buffer_size,
             distributed=self._distributed,
             drop_uneven_inputs=self._drop_uneven_inputs,
             world_size=self._world_size,
@@ -563,6 +604,16 @@ class DistributedItemSampler(ItemSampler):
         https://pytorch.org/tutorials/advanced/generic_join.html. However, this
         option can be used if the Join Context Manager is not helpful for any
         reason.
+    buffer_size : int
+        The size of the buffer to store items sliced from the :class:`ItemSet`
+        or :class:`ItemSetDict`. By default, it is set to -1, which means the
+        buffer size equals the total number of items in the item set. If the
+        item set is too large, it is recommended to set a smaller buffer size
+        to avoid out-of-memory errors. As items are only shuffled within each
+        buffer, a smaller buffer size reduces randomness, which can in turn
+        degrade training performance such as convergence speed and accuracy.
+        Therefore, it is recommended to set as large a buffer size as
+        possible.

     Examples
     --------
@@ -667,6 +718,7 @@ def __init__(
         drop_last: Optional[bool] = False,
         shuffle: Optional[bool] = False,
         drop_uneven_inputs: Optional[bool] = False,
+        buffer_size: Optional[int] = -1,
     ) -> None:
         super().__init__(
             item_set,
@@ -675,6 +727,7 @@ def __init__(
             drop_last,
             shuffle,
             use_indexing=True,
+            buffer_size=buffer_size,
         )
         self._distributed = True
         self._drop_uneven_inputs = drop_uneven_inputs
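
A minimal usage sketch of the new `buffer_size` knob, assuming the `dgl.graphbolt` `ItemSet`/`ItemSampler` API as it appears in this diff; the toy data and the `names="seed_nodes"` key are illustrative, not prescribed by the patch:

    import torch
    import dgl.graphbolt as gb

    # A toy indexable item set of 1024 seed node IDs.
    item_set = gb.ItemSet(torch.arange(0, 1024), names="seed_nodes")

    # Default buffer_size=-1: since the item set is indexable, the buffer
    # holds all 1024 items, so shuffling spans the whole item set.
    full_shuffle_sampler = gb.ItemSampler(item_set, batch_size=4, shuffle=True)

    # Explicit buffer_size: items are shuffled only within each 256-item
    # buffer, trading shuffling randomness for lower peak memory.
    bounded_sampler = gb.ItemSampler(
        item_set, batch_size=4, shuffle=True, buffer_size=256
    )

    for minibatch in bounded_sampler:
        pass  # feed each minibatch to downstream sampling/training stages

    # DistributedItemSampler accepts the same buffer_size argument; it
    # additionally requires torch.distributed to be initialized, one
    # replica per rank.

With `shuffle=False` the buffer size does not affect batch order, so the memory/randomness trade-off only matters for shuffled sampling.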