Skip to content

Commit

Permalink
FInal refactoring of NeighborSampler.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jan 30, 2024
1 parent c3a903d commit dcbfb4e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 414 deletions.
178 changes: 10 additions & 168 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .sampled_subgraph_impl import SampledSubgraphImpl


__all__ = ["NeighborSampler", "LayerNeighborSampler", "NeighborSampler2"]
__all__ = ["NeighborSampler", "LayerNeighborSampler"]


@functional_datapipe("sample_per_layer")
Expand Down Expand Up @@ -70,8 +70,8 @@ def _compact_per_layer(self, minibatch):
return minibatch


@functional_datapipe("sample_neighbor2")
class NeighborSampler2(MiniBatchTransformer):
@functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler):
"""Sample neighbor edges from a graph and return a subgraph.
Functional name: :obj:`sample_neighbor`.
Expand Down Expand Up @@ -163,11 +163,10 @@ def __init__(
deduplicate=True,
sampler="sample_neighbors",
):
super().__init__(datapipe)
self.graph = graph
datapipe = datapipe.sample_subgraph_preprocess()

def prepare(minibatch_and_seeds_timestamp):
minibatch, _ = minibatch_and_seeds_timestamp
def prepare(minibatch):
seeds = minibatch.input_nodes
# Enrich seeds with all node types.
if isinstance(seeds, dict):
Expand All @@ -187,173 +186,16 @@ def prepare(minibatch_and_seeds_timestamp):
minibatch.sampled_subgraphs = []
return minibatch

datapipe = datapipe.transform(prepare)
self.append_sampling_step(MiniBatchTransformer, prepare)
sampler = getattr(graph, sampler)
for fanout in reversed(fanouts):
# Convert fanout to tensor.
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
datapipe = datapipe.sample_per_layer(
sampler, fanout, replace, prob_name
)
datapipe = datapipe.compact_per_layer(deduplicate)
super().__init__(datapipe, lambda minibatch: minibatch)


@functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler):
"""Sample neighbor edges from a graph and return a subgraph.
Functional name: :obj:`sample_neighbor`.
Neighbor sampler is responsible for sampling a subgraph from given data. It
returns an induced subgraph along with compacted information. In the
context of a node classification task, the neighbor sampler directly
utilizes the nodes provided as seed nodes. However, in scenarios involving
link prediction, the process needs another pre-peocess operation. That is,
gathering unique nodes from the given node pairs, encompassing both
positive and negative node pairs, and employs these nodes as the seed nodes
for subsequent steps.
Parameters
----------
datapipe : DataPipe
The datapipe.
graph : FusedCSCSamplingGraph
The graph on which to perform subgraph sampling.
fanouts: list[torch.Tensor] or list[int]
The number of edges to be sampled for each node with or without
considering edge types. The length of this parameter implicitly
signifies the layer of sampling being conducted.
Note: The fanout order is from the outermost layer to innermost layer.
For example, the fanout '[15, 10, 5]' means that 15 to the outermost
layer, 10 to the intermediate layer and 5 corresponds to the innermost
layer.
replace: bool
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
prob_name: str, optional
The name of an edge attribute used as the weights of sampling for
each node. This attribute tensor should contain (unnormalized)
probabilities corresponding to each neighboring edge of a node.
It must be a 1D floating-point or boolean tensor, with the number
of elements equalling the total number of edges.
deduplicate: bool
Boolean indicating whether seeds between hops will be deduplicated.
If True, the same elements in seeds will be deleted to only one.
Otherwise, the same elements will be remained.
Examples
-------
>>> import torch
>>> import dgl.graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> datapipe = gb.ItemSampler(item_set, batch_size=1)
>>> datapipe = datapipe.sample_uniform_negative(graph, 2)
>>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15])
>>> next(iter(datapipe)).sampled_subgraphs
[SampledSubgraphImpl(sampled_csc=CSCFormatBase(
indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),
),
original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(
indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),
),
original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(
indptr=tensor([0, 2, 4, 5, 6]),
indices=tensor([1, 4, 0, 5, 5, 3]),
),
original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 4, 5]),
)]
"""

def __init__(
self,
datapipe,
graph,
fanouts,
replace=False,
prob_name=None,
deduplicate=True,
):
super().__init__(datapipe)
self.graph = graph
# Convert fanouts to a list of tensors.
self.fanouts = []
for fanout in fanouts:
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout)
self.replace = replace
self.prob_name = prob_name
self.deduplicate = deduplicate
self.sampler = graph.sample_neighbors

def sample_subgraphs(self, seeds, seeds_timestamp):
subgraphs = []
num_layers = len(self.fanouts)
# Enrich seeds with all node types.
if isinstance(seeds, dict):
ntypes = list(self.graph.node_type_to_id.keys())
# Loop over different seeds to extract the device they are on.
device = None
dtype = None
for _, seed in seeds.items():
device = seed.device
dtype = seed.dtype
break
default_tensor = torch.tensor([], dtype=dtype, device=device)
seeds = {
ntype: seeds.get(ntype, default_tensor) for ntype in ntypes
}
for hop in range(num_layers):
subgraph = self.sampler(
seeds,
self.fanouts[hop],
self.replace,
self.prob_name,
self.append_sampling_step(
SamplePerLayer, sampler, fanout, replace, prob_name
)
if self.deduplicate:
(
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
compacted_csc_format,
) = compact_csc_format(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
subgraphs.insert(0, subgraph)
seeds = original_row_node_ids
return seeds, subgraphs
self.append_sampling_step(CompactPerLayer, deduplicate)


@functional_datapipe("sample_layer_neighbor")
Expand Down Expand Up @@ -468,5 +310,5 @@ def __init__(
replace,
prob_name,
deduplicate,
sampler="sample_layer_neighbors",
)
self.sampler = graph.sample_layer_neighbors
10 changes: 8 additions & 2 deletions python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch.utils.data import functional_datapipe

from ..internal import compact_csc_format
from ..minibatch_transformer import MiniBatchTransformer

from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import SampledSubgraphImpl
Expand Down Expand Up @@ -88,8 +89,11 @@ def __init__(
self.node_timestamp_attr_name = node_timestamp_attr_name
self.edge_timestamp_attr_name = edge_timestamp_attr_name
self.sampler = graph.temporal_sample_neighbors
self.append_sampling_step(MiniBatchTransformer, self._sample_subgraphs)

def sample_subgraphs(self, seeds, seeds_timestamp):
def _sample_subgraphs(self, minibatch):
seeds = minibatch.input_nodes
seeds_timestamp = minibatch.seeds_timestamp
assert (
seeds_timestamp is not None
), "seeds_timestamp must be provided for temporal neighbor sampling."
Expand Down Expand Up @@ -132,4 +136,6 @@ def sample_subgraphs(self, seeds, seeds_timestamp):
subgraphs.insert(0, subgraph)
seeds = original_row_node_ids
seeds_timestamp = row_timestamps
return seeds, subgraphs
minibatch.input_nodes = seeds
minibatch.sampled_subgraphs = subgraphs
return minibatch
Loading

0 comments on commit dcbfb4e

Please sign in to comment.