[GraphBolt] Refactor NeighborSampler and expose fine-grained datapipes. #6983

Merged Feb 1, 2024 (40 commits)

Commits
b4c045f
prototyping
mfbalin Jan 19, 2024
a26e354
fix bug
mfbalin Jan 20, 2024
58a1190
remove print expressions, works now
mfbalin Jan 20, 2024
eedb8d1
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 23, 2024
3d2ed98
add tests
mfbalin Jan 24, 2024
105924a
use seeds_timestamp in preprocess
mfbalin Jan 24, 2024
bfb28ec
add docstring for linting
mfbalin Jan 24, 2024
e4becc9
fix linting
mfbalin Jan 24, 2024
428ff24
fix argument bug
mfbalin Jan 24, 2024
85b0601
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 24, 2024
866316e
fix the bug
mfbalin Jan 24, 2024
e2793fd
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 24, 2024
2473722
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 27, 2024
2d1dda9
address reviews
mfbalin Jan 29, 2024
fad7c50
add docstring to the new `MinibatchTransformer`.
mfbalin Jan 29, 2024
a8fdfc6
address review properly.
mfbalin Jan 29, 2024
933246f
remove unused `Mapper` import for linting.
mfbalin Jan 29, 2024
cd68728
NeighborSampler2 now derives from `MinibatchTransformer`.
mfbalin Jan 30, 2024
c3a903d
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 30, 2024
dcbfb4e
FInal refactoring of NeighborSampler.
mfbalin Jan 30, 2024
21fe633
Fix not only preprocess but also postprocess issue.
mfbalin Jan 30, 2024
29861f1
take back test changes.
mfbalin Jan 30, 2024
232f2f3
fix in_subgraph_sampler
mfbalin Jan 30, 2024
03bea25
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 30, 2024
86d9c43
add docstring for `append_sampling_step`.
mfbalin Jan 30, 2024
f995d20
Address reviews, minimize changes, keep API exactly the same.
mfbalin Jan 31, 2024
a64d34e
remove leftover changes.
mfbalin Jan 31, 2024
e46b8c7
minor change.
mfbalin Jan 31, 2024
8cc858c
Make the function into a proper one so that it can be pickled.
mfbalin Jan 31, 2024
02ca357
make the lambda into a proper function so that it can be pickled.
mfbalin Jan 31, 2024
19b4367
linting.
mfbalin Jan 31, 2024
67d6f71
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 31, 2024
144134c
final linting.
mfbalin Jan 31, 2024
96bac52
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 31, 2024
718cab8
Cleanup NeighborSampler as it does not need to store anything itself.
mfbalin Jan 31, 2024
ee3a7d7
linting
mfbalin Jan 31, 2024
1d906e7
address reviews by not passing sampler as string argument.
mfbalin Feb 1, 2024
6ab2e75
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Feb 1, 2024
6f880c0
Talk about `sampling_stages` in the SubgraphSampler API.
mfbalin Feb 1, 2024
5d907ee
add more documentation for `sampling_stages`.
mfbalin Feb 1, 2024
146 changes: 98 additions & 48 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -1,9 +1,12 @@
"""Neighbor subgraph samplers for GraphBolt."""

from functools import partial

import torch
from torch.utils.data import functional_datapipe

from ..internal import compact_csc_format, unique_and_compact_csc_formats
from ..minibatch_transformer import MiniBatchTransformer

from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import SampledSubgraphImpl
@@ -12,8 +15,66 @@
__all__ = ["NeighborSampler", "LayerNeighborSampler"]


@functional_datapipe("sample_per_layer")
class SamplePerLayer(MiniBatchTransformer):
"""Sample neighbor edges from a graph for a single layer."""

def __init__(self, datapipe, sampler, fanout, replace, prob_name):
super().__init__(datapipe, self._sample_per_layer)
self.sampler = sampler
self.fanout = fanout
self.replace = replace
self.prob_name = prob_name

def _sample_per_layer(self, minibatch):
subgraph = self.sampler(
minibatch._seed_nodes, self.fanout, self.replace, self.prob_name
)
minibatch.sampled_subgraphs.insert(0, subgraph)
return minibatch


@functional_datapipe("compact_per_layer")
class CompactPerLayer(MiniBatchTransformer):
"""Compact the sampled edges for a single layer."""

def __init__(self, datapipe, deduplicate):
super().__init__(datapipe, self._compact_per_layer)
self.deduplicate = deduplicate

def _compact_per_layer(self, minibatch):
subgraph = minibatch.sampled_subgraphs[0]
seeds = minibatch._seed_nodes
if self.deduplicate:
(
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
compacted_csc_format,
) = compact_csc_format(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
minibatch._seed_nodes = original_row_node_ids
minibatch.sampled_subgraphs[0] = subgraph
return minibatch


@functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler):
# pylint: disable=abstract-method
"""Sample neighbor edges from a graph and return a subgraph.

Functional name: :obj:`sample_neighbor`.
@@ -95,6 +156,7 @@ class NeighborSampler(SubgraphSampler):
)]
"""

# pylint: disable=useless-super-delegation
def __init__(
self,
datapipe,
@@ -103,26 +165,17 @@ def __init__(
replace=False,
prob_name=None,
deduplicate=True,
sampler="sample_neighbors",
):
super().__init__(datapipe)
self.graph = graph
# Convert fanouts to a list of tensors.
self.fanouts = []
for fanout in fanouts:
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout)
self.replace = replace
self.prob_name = prob_name
self.deduplicate = deduplicate
self.sampler = graph.sample_neighbors
super().__init__(
datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
)

def sample_subgraphs(self, seeds, seeds_timestamp):
subgraphs = []
num_layers = len(self.fanouts)
def _prepare(self, node_type_to_id, minibatch):
seeds = minibatch._seed_nodes
# Enrich seeds with all node types.
if isinstance(seeds, dict):
ntypes = list(self.graph.node_type_to_id.keys())
ntypes = list(node_type_to_id.keys())
# Loop over different seeds to extract the device they are on.
device = None
dtype = None
@@ -134,42 +187,38 @@ def sample_subgraphs(self, seeds, seeds_timestamp):
seeds = {
ntype: seeds.get(ntype, default_tensor) for ntype in ntypes
}
for hop in range(num_layers):
subgraph = self.sampler(
seeds,
self.fanouts[hop],
self.replace,
self.prob_name,
minibatch._seed_nodes = seeds
minibatch.sampled_subgraphs = []
return minibatch

@staticmethod
def _set_input_nodes(minibatch):
minibatch.input_nodes = minibatch._seed_nodes
return minibatch

# pylint: disable=arguments-differ
def sampling_stages(
self, datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
):
datapipe = datapipe.transform(
partial(self._prepare, graph.node_type_to_id)
)
sampler = getattr(graph, sampler)
for fanout in reversed(fanouts):
# Convert fanout to tensor.
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
datapipe = datapipe.sample_per_layer(
sampler, fanout, replace, prob_name
)
if self.deduplicate:
(
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
compacted_csc_format,
) = compact_csc_format(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
subgraphs.insert(0, subgraph)
seeds = original_row_node_ids
return seeds, subgraphs
datapipe = datapipe.compact_per_layer(deduplicate)

return datapipe.transform(self._set_input_nodes)


@functional_datapipe("sample_layer_neighbor")
class LayerNeighborSampler(NeighborSampler):
# pylint: disable=abstract-method
"""Sample layer neighbor edges from a graph and return a subgraph.

Functional name: :obj:`sample_layer_neighbor`.
@@ -272,6 +321,7 @@ def __init__(
replace=False,
prob_name=None,
deduplicate=True,
sampler="sample_layer_neighbors",
):
super().__init__(
datapipe,
@@ -280,5 +330,5 @@
replace,
prob_name,
deduplicate,
sampler,
)
self.sampler = graph.sample_layer_neighbors
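
The refactor keeps the user-facing `sample_neighbor` / `sample_layer_neighbor` API unchanged while registering the per-layer steps as standalone datapipes. Below is a minimal sketch of chaining the newly exposed `sample_per_layer` and `compact_per_layer` datapipes by hand; it assumes `graph` is a GraphBolt sampling graph and that `datapipe` already yields minibatches whose `_seed_nodes` and `sampled_subgraphs` fields have been initialized (as `NeighborSampler._prepare` does), so any name outside the diff is illustrative only.

import torch

# Chain the layers innermost-first, mirroring sampling_stages above.
fanouts = [torch.LongTensor([10]), torch.LongTensor([10])]
for fanout in reversed(fanouts):
    # Sample one hop using the graph's neighbor sampling routine.
    datapipe = datapipe.sample_per_layer(
        graph.sample_neighbors, fanout, False, None
    )
    # Compact the sampled hop, deduplicating node IDs.
    datapipe = datapipe.compact_per_layer(True)

Because both classes are registered with `@functional_datapipe`, they become available as methods on any GraphBolt datapipe once this module is imported.
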
48 changes: 40 additions & 8 deletions python/dgl/graphbolt/subgraph_sampler.py
@@ -28,23 +28,43 @@ class SubgraphSampler(MiniBatchTransformer):
----------
datapipe : DataPipe
The datapipe.
args : Non-Keyword Arguments
Arguments to be passed into sampling_stages.
kwargs : Keyword Arguments
Arguments to be passed into sampling_stages.
"""

def __init__(
self,
datapipe,
*args,
**kwargs,
):
super().__init__(datapipe, self._sample)
datapipe = datapipe.transform(self._preprocess)
datapipe = self.sampling_stages(datapipe, *args, **kwargs)
datapipe = datapipe.transform(self._postprocess)
super().__init__(datapipe, self._identity)

def _sample(self, minibatch):
@staticmethod
def _identity(minibatch):
return minibatch

@staticmethod
def _postprocess(minibatch):
delattr(minibatch, "_seed_nodes")
delattr(minibatch, "_seeds_timestamp")
return minibatch

@staticmethod
def _preprocess(minibatch):
if minibatch.node_pairs is not None:
(
seeds,
seeds_timestamp,
minibatch.compacted_node_pairs,
minibatch.compacted_negative_srcs,
minibatch.compacted_negative_dsts,
) = self._node_pairs_preprocess(minibatch)
) = SubgraphSampler._node_pairs_preprocess(minibatch)
elif minibatch.seed_nodes is not None:
seeds = minibatch.seed_nodes
seeds_timestamp = (
@@ -55,13 +75,12 @@ def _sample(self, minibatch):
f"Invalid minibatch {minibatch}: Either `node_pairs` or "
"`seed_nodes` should have a value."
)
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self.sample_subgraphs(seeds, seeds_timestamp)
minibatch._seed_nodes = seeds
minibatch._seeds_timestamp = seeds_timestamp
return minibatch

def _node_pairs_preprocess(self, minibatch):
@staticmethod
def _node_pairs_preprocess(minibatch):
use_timestamp = hasattr(minibatch, "timestamp")
node_pairs = minibatch.node_pairs
neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
@@ -191,6 +210,19 @@ def _node_pairs_preprocess(minibatch):
compacted_negative_dsts if has_neg_dst else None,
)

def _sample(self, minibatch):
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self.sample_subgraphs(
minibatch._seed_nodes, minibatch._seeds_timestamp
)
return minibatch

def sampling_stages(self, datapipe):
"""The sampling stages are defined here by chaining to the datapipe."""
return datapipe.transform(self._sample)

def sample_subgraphs(self, seeds, seeds_timestamp):
"""Sample subgraphs from the given seeds, possibly with temporal constraints.

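
With this change the base class drives the whole pipeline: `__init__` chains `_preprocess`, then whatever `sampling_stages` returns, then `_postprocess`, so a subclass only describes its stages by chaining datapipes. A minimal sketch of a single-hop sampler built on this extension point follows; the class name, helper names, and the `from dgl.graphbolt import ...` path are assumptions for illustration, and the per-layer datapipes come from the file above.

import torch
from dgl.graphbolt import SubgraphSampler  # assumed import path

class SingleHopSampler(SubgraphSampler):  # hypothetical example class
    """Sample one hop of neighbors via the new sampling_stages hook."""

    def __init__(self, datapipe, graph, fanout):
        # Extra positional arguments are forwarded to sampling_stages.
        super().__init__(datapipe, graph, fanout)

    def sampling_stages(self, datapipe, graph, fanout):
        # _preprocess has already populated minibatch._seed_nodes; start
        # with an empty subgraph list, sample one hop, then compact it.
        datapipe = datapipe.transform(self._init_subgraphs)
        datapipe = datapipe.sample_per_layer(
            graph.sample_neighbors, fanout, False, None
        )
        datapipe = datapipe.compact_per_layer(True)
        return datapipe.transform(self._set_input_nodes)

    @staticmethod
    def _init_subgraphs(minibatch):
        minibatch.sampled_subgraphs = []
        return minibatch

    @staticmethod
    def _set_input_nodes(minibatch):
        minibatch.input_nodes = minibatch._seed_nodes
        return minibatch

Expressed this way, individual stages can be reordered or swapped out without touching SubgraphSampler itself, which is the point of exposing the fine-grained datapipes.
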