diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py index f89e43e0c8d9..b0dd9daccfaf 100644 --- a/python/dgl/graphbolt/dataloader.py +++ b/python/dgl/graphbolt/dataloader.py @@ -1,6 +1,6 @@ """Graph Bolt DataLoaders""" -from queue import Queue +from collections import deque import torch import torch.utils.data @@ -69,18 +69,18 @@ def __init__(self, datapipe, buffer_size=1): raise ValueError( "'buffer_size' is required to be a positive integer." ) - self.buffer = Queue(buffer_size) + self.buffer = deque(maxlen=buffer_size) def __iter__(self): for data in self.datapipe: - if not self.buffer.full(): - self.buffer.put(data) + if len(self.buffer) < self.buffer.maxlen: + self.buffer.append(data) else: - return_data = self.buffer.get() - self.buffer.put(data) + return_data = self.buffer.popleft() + self.buffer.append(data) yield return_data - while not self.buffer.empty(): - yield self.buffer.get() + while len(self.buffer) > 0: + yield self.buffer.popleft() class Awaiter(dp.iter.IterDataPipe):