Skip to content

Commit

Permalink
[GraphBolt] Refactor NeighborSampler and expose fine-grained datapipe…
Browse files Browse the repository at this point in the history
…s. (#6983)
  • Loading branch information
mfbalin authored Feb 1, 2024
1 parent e602ab1 commit 50eb101
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 57 deletions.
146 changes: 98 additions & 48 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Neighbor subgraph samplers for GraphBolt."""

from functools import partial

import torch
from torch.utils.data import functional_datapipe

from ..internal import compact_csc_format, unique_and_compact_csc_formats
from ..minibatch_transformer import MiniBatchTransformer

from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import SampledSubgraphImpl
Expand All @@ -12,8 +15,66 @@
__all__ = ["NeighborSampler", "LayerNeighborSampler"]


@functional_datapipe("sample_per_layer")
class SamplePerLayer(MiniBatchTransformer):
"""Sample neighbor edges from a graph for a single layer."""

def __init__(self, datapipe, sampler, fanout, replace, prob_name):
super().__init__(datapipe, self._sample_per_layer)
self.sampler = sampler
self.fanout = fanout
self.replace = replace
self.prob_name = prob_name

def _sample_per_layer(self, minibatch):
subgraph = self.sampler(
minibatch._seed_nodes, self.fanout, self.replace, self.prob_name
)
minibatch.sampled_subgraphs.insert(0, subgraph)
return minibatch


@functional_datapipe("compact_per_layer")
class CompactPerLayer(MiniBatchTransformer):
"""Compact the sampled edges for a single layer."""

def __init__(self, datapipe, deduplicate):
super().__init__(datapipe, self._compact_per_layer)
self.deduplicate = deduplicate

def _compact_per_layer(self, minibatch):
subgraph = minibatch.sampled_subgraphs[0]
seeds = minibatch._seed_nodes
if self.deduplicate:
(
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
compacted_csc_format,
) = compact_csc_format(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
minibatch._seed_nodes = original_row_node_ids
minibatch.sampled_subgraphs[0] = subgraph
return minibatch


@functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler):
# pylint: disable=abstract-method
"""Sample neighbor edges from a graph and return a subgraph.
Functional name: :obj:`sample_neighbor`.
Expand Down Expand Up @@ -95,6 +156,7 @@ class NeighborSampler(SubgraphSampler):
)]
"""

# pylint: disable=useless-super-delegation
def __init__(
self,
datapipe,
Expand All @@ -103,26 +165,19 @@ def __init__(
replace=False,
prob_name=None,
deduplicate=True,
sampler=None,
):
super().__init__(datapipe)
self.graph = graph
# Convert fanouts to a list of tensors.
self.fanouts = []
for fanout in fanouts:
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout)
self.replace = replace
self.prob_name = prob_name
self.deduplicate = deduplicate
self.sampler = graph.sample_neighbors
if sampler is None:
sampler = graph.sample_neighbors
super().__init__(
datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
)

def sample_subgraphs(self, seeds, seeds_timestamp):
subgraphs = []
num_layers = len(self.fanouts)
def _prepare(self, node_type_to_id, minibatch):
seeds = minibatch._seed_nodes
# Enrich seeds with all node types.
if isinstance(seeds, dict):
ntypes = list(self.graph.node_type_to_id.keys())
ntypes = list(node_type_to_id.keys())
# Loop over different seeds to extract the device they are on.
device = None
dtype = None
Expand All @@ -134,42 +189,37 @@ def sample_subgraphs(self, seeds, seeds_timestamp):
seeds = {
ntype: seeds.get(ntype, default_tensor) for ntype in ntypes
}
for hop in range(num_layers):
subgraph = self.sampler(
seeds,
self.fanouts[hop],
self.replace,
self.prob_name,
minibatch._seed_nodes = seeds
minibatch.sampled_subgraphs = []
return minibatch

@staticmethod
def _set_input_nodes(minibatch):
minibatch.input_nodes = minibatch._seed_nodes
return minibatch

# pylint: disable=arguments-differ
def sampling_stages(
self, datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
):
datapipe = datapipe.transform(
partial(self._prepare, graph.node_type_to_id)
)
for fanout in reversed(fanouts):
# Convert fanout to tensor.
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
datapipe = datapipe.sample_per_layer(
sampler, fanout, replace, prob_name
)
if self.deduplicate:
(
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
compacted_csc_format,
) = compact_csc_format(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
subgraphs.insert(0, subgraph)
seeds = original_row_node_ids
return seeds, subgraphs
datapipe = datapipe.compact_per_layer(deduplicate)

return datapipe.transform(self._set_input_nodes)


@functional_datapipe("sample_layer_neighbor")
class LayerNeighborSampler(NeighborSampler):
# pylint: disable=abstract-method
"""Sample layer neighbor edges from a graph and return a subgraph.
Functional name: :obj:`sample_layer_neighbor`.
Expand Down Expand Up @@ -280,5 +330,5 @@ def __init__(
replace,
prob_name,
deduplicate,
graph.sample_layer_neighbors,
)
self.sampler = graph.sample_layer_neighbors
57 changes: 48 additions & 9 deletions python/dgl/graphbolt/subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,52 @@ class SubgraphSampler(MiniBatchTransformer):
Functional name: :obj:`sample_subgraph`.
This class is the base class of all subgraph samplers. Any subclass of
SubgraphSampler should implement the :meth:`sample_subgraphs` method.
SubgraphSampler should implement either the :meth:`sample_subgraphs` method
or the :meth:`sampling_stages` method to define the fine-grained sampling
stages to take advantage of optimizations provided by the GraphBolt
DataLoader.
Parameters
----------
datapipe : DataPipe
The datapipe.
args : Non-Keyword Arguments
Arguments to be passed into sampling_stages.
kwargs : Keyword Arguments
Arguments to be passed into sampling_stages.
"""

def __init__(
self,
datapipe,
*args,
**kwargs,
):
super().__init__(datapipe, self._sample)
datapipe = datapipe.transform(self._preprocess)
datapipe = self.sampling_stages(datapipe, *args, **kwargs)
datapipe = datapipe.transform(self._postprocess)
super().__init__(datapipe, self._identity)

def _sample(self, minibatch):
@staticmethod
def _identity(minibatch):
return minibatch

@staticmethod
def _postprocess(minibatch):
delattr(minibatch, "_seed_nodes")
delattr(minibatch, "_seeds_timestamp")
return minibatch

@staticmethod
def _preprocess(minibatch):
if minibatch.node_pairs is not None:
(
seeds,
seeds_timestamp,
minibatch.compacted_node_pairs,
minibatch.compacted_negative_srcs,
minibatch.compacted_negative_dsts,
) = self._node_pairs_preprocess(minibatch)
) = SubgraphSampler._node_pairs_preprocess(minibatch)
elif minibatch.seed_nodes is not None:
seeds = minibatch.seed_nodes
seeds_timestamp = (
Expand All @@ -55,13 +78,12 @@ def _sample(self, minibatch):
f"Invalid minibatch {minibatch}: Either `node_pairs` or "
"`seed_nodes` should have a value."
)
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self.sample_subgraphs(seeds, seeds_timestamp)
minibatch._seed_nodes = seeds
minibatch._seeds_timestamp = seeds_timestamp
return minibatch

def _node_pairs_preprocess(self, minibatch):
@staticmethod
def _node_pairs_preprocess(minibatch):
use_timestamp = hasattr(minibatch, "timestamp")
node_pairs = minibatch.node_pairs
neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
Expand Down Expand Up @@ -191,6 +213,23 @@ def _node_pairs_preprocess(self, minibatch):
compacted_negative_dsts if has_neg_dst else None,
)

def _sample(self, minibatch):
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self.sample_subgraphs(
minibatch._seed_nodes, minibatch._seeds_timestamp
)
return minibatch

def sampling_stages(self, datapipe):
"""The sampling stages are defined here by chaining to the datapipe. The
default implementation expects :meth:`sample_subgraphs` to be
implemented. To define fine-grained stages, this method should be
overridden.
"""
return datapipe.transform(self._sample)

def sample_subgraphs(self, seeds, seeds_timestamp):
"""Sample subgraphs from the given seeds, possibly with temporal constraints.
Expand Down

0 comments on commit 50eb101

Please sign in to comment.