Commit

[Graphbolt] Rewrite to_dgl to multiple get functions (#6735)

RamonZhou authored Dec 15, 2023
1 parent 70fdb69 commit cad7cae
Showing 10 changed files with 562 additions and 424 deletions.
examples/sampling/graphbolt/quickstart/link_prediction.py (4 additions, 10 deletions)
@@ -98,11 +98,8 @@ def evaluate(model, dataset, device):
     logits = []
     labels = []
     for step, data in enumerate(dataloader):
-        # Convert data to DGL format for computing.
-        data = data.to_dgl()
-
-        # Unpack MiniBatch.
-        compacted_pairs, label = to_binary_link_dgl_computing_pack(data)
+        # Get node pairs with labels for loss calculation.
+        compacted_pairs, label = data.node_pairs_with_labels

         # The features of sampled nodes.
         x = data.node_features["feat"]
@@ -140,11 +137,8 @@ def train(model, dataset, device):
     # mini-batches.
     ########################################################################
     for step, data in enumerate(dataloader):
-        # Convert data to DGL format for computing.
-        data = data.to_dgl()
-
-        # Unpack MiniBatch.
-        compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
+        # Get node pairs with labels for loss calculation.
+        compacted_pairs, labels = data.node_pairs_with_labels

         # The features of sampled nodes.
         x = data.node_features["feat"]
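The property returns a `(node_pairs, labels)` pair in which positive pairs are labeled 1 and negative pairs 0, so the loop can feed both straight into a binary loss. A minimal sketch of that consumption with made-up scores and counts (not code from this commit):

import torch
import torch.nn.functional as F

# Toy stand-ins: logits for 4 positive and 8 negative node pairs.
logits = torch.randn(12)
labels = torch.cat([torch.ones(4), torch.zeros(8)])  # 1 = positive, 0 = negative

# The labels from `node_pairs_with_labels` plug into a binary loss directly.
loss = F.binary_cross_entropy_with_logits(logits, labels)
print(loss.item())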
examples/sampling/graphbolt/quickstart/node_classification.py (0 additions, 4 deletions)
@@ -57,7 +57,6 @@ def evaluate(model, dataset, itemset, device):
     dataloader = create_dataloader(dataset, itemset, device)

     for step, data in enumerate(dataloader):
-        data = data.to_dgl()
         x = data.node_features["feat"]
         y.append(data.labels)
         y_hats.append(model(data.blocks, x))
@@ -84,9 +83,6 @@ def train(model, dataset, device):
     # mini-batches.
     ########################################################################
     for step, data in enumerate(dataloader):
-        # Convert data to DGL format for computing.
-        data = data.to_dgl()
-
         # The features of sampled nodes.
         x = data.node_features["feat"]

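With `to_dgl()` gone, `data.blocks` is read directly as a property and handed to the model. A minimal sketch of a model consuming one DGL block per layer (the two-layer depth and "mean" aggregator are assumptions for illustration, not code from this commit):

import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class SAGE(nn.Module):
    """Two-layer GraphSAGE whose forward takes `blocks` and input features."""

    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList([
            dglnn.SAGEConv(in_size, hidden_size, "mean"),
            dglnn.SAGEConv(hidden_size, out_size, "mean"),
        ])

    def forward(self, blocks, x):
        h = x
        for i, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)  # one block per layer, as in model(data.blocks, x)
            if i != len(self.layers) - 1:
                h = F.relu(h)
        return h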
python/dgl/graphbolt/minibatch.py (135 additions, 94 deletions)
@@ -363,9 +363,10 @@ def set_edge_features(
"""Set edge features."""
self.edge_features = edge_features

def _to_dgl_blocks(self):
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing
a graphical structure and ID mappings.
@property
def blocks(self):
"""Extracts DGL blocks from `MiniBatch` to construct a graphical
structure and ID mappings.
"""
if not self.sampled_subgraphs:
return None
@@ -459,98 +460,135 @@ def _to_dgl_blocks(self):
                 block.edata[dgl.EID] = subgraph.original_edge_ids
         return blocks

-    def to_dgl(self):
-        """Converting a `MiniBatch` into a DGL MiniBatch that contains
-        everything necessary for computation."
-        """
-        minibatch = DGLMiniBatch(
-            blocks=self._to_dgl_blocks(),
-            node_features=self.node_features,
-            edge_features=self.edge_features,
-            labels=self.labels,
-        )
-        # Need input nodes to fetch feature.
-        if self.node_features is None:
-            minibatch.input_nodes = self.input_nodes
-        # Need output nodes to fetch label.
-        if self.labels is None:
-            minibatch.output_nodes = self.seed_nodes
-        assert (
-            minibatch.blocks is not None
-        ), "Sampled subgraphs for computation are missing."
-
-        # For link prediction tasks.
-        if self.compacted_node_pairs is not None:
-            minibatch.positive_node_pairs = self.compacted_node_pairs
-            # Build negative graph.
-            if (
-                self.compacted_negative_srcs is not None
-                and self.compacted_negative_dsts is not None
-            ):
-                # For homogeneous graph.
-                if isinstance(self.compacted_negative_srcs, torch.Tensor):
-                    minibatch.negative_node_pairs = (
-                        self.compacted_negative_srcs.view(-1),
-                        self.compacted_negative_dsts.view(-1),
-                    )
-                # For heterogeneous graph.
-                else:
-                    minibatch.negative_node_pairs = {
-                        etype: (
-                            neg_src.view(-1),
-                            self.compacted_negative_dsts[etype].view(-1),
-                        )
-                        for etype, neg_src in self.compacted_negative_srcs.items()
-                    }
-            elif self.compacted_negative_srcs is not None:
-                # For homogeneous graph.
-                if isinstance(self.compacted_negative_srcs, torch.Tensor):
-                    negative_ratio = self.compacted_negative_srcs.size(1)
-                    minibatch.negative_node_pairs = (
-                        self.compacted_negative_srcs.view(-1),
-                        self.compacted_node_pairs[1].repeat_interleave(
-                            negative_ratio
-                        ),
-                    )
-                # For heterogeneous graph.
-                else:
-                    negative_ratio = list(
-                        self.compacted_negative_srcs.values()
-                    )[0].size(1)
-                    minibatch.negative_node_pairs = {
-                        etype: (
-                            neg_src.view(-1),
-                            self.compacted_node_pairs[etype][
-                                1
-                            ].repeat_interleave(negative_ratio),
-                        )
-                        for etype, neg_src in self.compacted_negative_srcs.items()
-                    }
-            elif self.compacted_negative_dsts is not None:
-                # For homogeneous graph.
-                if isinstance(self.compacted_negative_dsts, torch.Tensor):
-                    negative_ratio = self.compacted_negative_dsts.size(1)
-                    minibatch.negative_node_pairs = (
-                        self.compacted_node_pairs[0].repeat_interleave(
-                            negative_ratio
-                        ),
-                        self.compacted_negative_dsts.view(-1),
-                    )
-                # For heterogeneous graph.
-                else:
-                    negative_ratio = list(
-                        self.compacted_negative_dsts.values()
-                    )[0].size(1)
-                    minibatch.negative_node_pairs = {
-                        etype: (
-                            self.compacted_node_pairs[etype][
-                                0
-                            ].repeat_interleave(negative_ratio),
-                            neg_dst.view(-1),
-                        )
-                        for etype, neg_dst in self.compacted_negative_dsts.items()
-                    }
-        return minibatch
+    @property
+    def positive_node_pairs(self):
+        """`positive_node_pairs` is a representation of positive graphs used for
+        evaluating or computing loss in link prediction tasks.
+        - If `positive_node_pairs` is a tuple: It indicates a homogeneous graph
+        containing two tensors representing source-destination node pairs.
+        - If `positive_node_pairs` is a dictionary: The keys should be edge type,
+        and the value should be a tuple of tensors representing node pairs of the
+        given type.
+        """
+        return self.compacted_node_pairs
+
+    @property
+    def negative_node_pairs(self):
+        """`negative_node_pairs` is a representation of negative graphs used for
+        evaluating or computing loss in link prediction tasks.
+        - If `negative_node_pairs` is a tuple: It indicates a homogeneous graph
+        containing two tensors representing source-destination node pairs.
+        - If `negative_node_pairs` is a dictionary: The keys should be edge type,
+        and the value should be a tuple of tensors representing node pairs of the
+        given type.
+        """
+        # Build negative graph.
+        if (
+            self.compacted_negative_srcs is not None
+            and self.compacted_negative_dsts is not None
+        ):
+            # For homogeneous graph.
+            if isinstance(self.compacted_negative_srcs, torch.Tensor):
+                negative_node_pairs = (
+                    self.compacted_negative_srcs.view(-1),
+                    self.compacted_negative_dsts.view(-1),
+                )
+            # For heterogeneous graph.
+            else:
+                negative_node_pairs = {
+                    etype: (
+                        neg_src.view(-1),
+                        self.compacted_negative_dsts[etype].view(-1),
+                    )
+                    for etype, neg_src in self.compacted_negative_srcs.items()
+                }
+        elif (
+            self.compacted_negative_srcs is not None
+            and self.compacted_node_pairs is not None
+        ):
+            # For homogeneous graph.
+            if isinstance(self.compacted_negative_srcs, torch.Tensor):
+                negative_ratio = self.compacted_negative_srcs.size(1)
+                negative_node_pairs = (
+                    self.compacted_negative_srcs.view(-1),
+                    self.compacted_node_pairs[1].repeat_interleave(
+                        negative_ratio
+                    ),
+                )
+            # For heterogeneous graph.
+            else:
+                negative_ratio = list(self.compacted_negative_srcs.values())[
+                    0
+                ].size(1)
+                negative_node_pairs = {
+                    etype: (
+                        neg_src.view(-1),
+                        self.compacted_node_pairs[etype][1].repeat_interleave(
+                            negative_ratio
+                        ),
+                    )
+                    for etype, neg_src in self.compacted_negative_srcs.items()
+                }
+        elif (
+            self.compacted_negative_dsts is not None
+            and self.compacted_node_pairs is not None
+        ):
+            # For homogeneous graph.
+            if isinstance(self.compacted_negative_dsts, torch.Tensor):
+                negative_ratio = self.compacted_negative_dsts.size(1)
+                negative_node_pairs = (
+                    self.compacted_node_pairs[0].repeat_interleave(
+                        negative_ratio
+                    ),
+                    self.compacted_negative_dsts.view(-1),
+                )
+            # For heterogeneous graph.
+            else:
+                negative_ratio = list(self.compacted_negative_dsts.values())[
+                    0
+                ].size(1)
+                negative_node_pairs = {
+                    etype: (
+                        self.compacted_node_pairs[etype][0].repeat_interleave(
+                            negative_ratio
+                        ),
+                        neg_dst.view(-1),
+                    )
+                    for etype, neg_dst in self.compacted_negative_dsts.items()
+                }
+        else:
+            negative_node_pairs = None
+        return negative_node_pairs
+
+    @property
+    def node_pairs_with_labels(self):
+        """Get a node pair tensor and a label tensor from MiniBatch. They are
+        used for evaluating or computing loss. It will return
+        `(node_pairs, labels)` as result.
+        - If it's a link prediction task, `node_pairs` will contain both
+        negative and positive node pairs and `labels` will consist of 0 and 1,
+        indicating whether the corresponding node pair is negative or positive.
+        - If it's an edge classification task, this function will directly
+        return `compacted_node_pairs` and corresponding `labels`.
+        - Otherwise it will return None.
+        """
+        if self.labels is None:
+            positive_node_pairs = self.positive_node_pairs
+            negative_node_pairs = self.negative_node_pairs
+            if positive_node_pairs is None or negative_node_pairs is None:
+                return None
+            pos_src, pos_dst = positive_node_pairs
+            neg_src, neg_dst = negative_node_pairs
+            node_pairs = (
+                torch.cat((pos_src, neg_src), dim=0),
+                torch.cat((pos_dst, neg_dst), dim=0),
+            )
+            pos_label = torch.ones_like(pos_src)
+            neg_label = torch.zeros_like(neg_src)
+            labels = torch.cat([pos_label, neg_label], dim=0)
+            return (node_pairs, labels.float())
+        else:
+            return (self.compacted_node_pairs, self.labels)

     def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
         """Copy `MiniBatch` to the specified device using reflection."""
@@ -561,13 +599,16 @@ def _to(x, device):
         for attr in dir(self):
             # Only copy member variables.
             if not callable(getattr(self, attr)) and not attr.startswith("__"):
-                setattr(
-                    self,
-                    attr,
-                    recursive_apply(
-                        getattr(self, attr), lambda x: _to(x, device)
-                    ),
-                )
+                try:
+                    setattr(
+                        self,
+                        attr,
+                        recursive_apply(
+                            getattr(self, attr), lambda x: _to(x, device)
+                        ),
+                    )
+                except AttributeError:
+                    continue

         return self

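The try/except added here is needed because the new read-only properties (`blocks`, `positive_node_pairs`, `negative_node_pairs`, `node_pairs_with_labels`) show up in `dir(self)` but define no setters, so `setattr` on them raises `AttributeError`. A minimal illustration of that behavior (not from the commit):

class Box:
    @property
    def value(self):  # read-only: no setter defined
        return 42

box = Box()
try:
    setattr(box, "value", 0)
except AttributeError:
    pass  # skipped, mirroring what MiniBatch.to() now does for properties
print(box.value)  # 42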
python/dgl/graphbolt/minibatch_transformer.py (0 additions, 20 deletions)
@@ -8,7 +8,6 @@

 __all__ = [
     "MiniBatchTransformer",
-    "DGLMiniBatchConverter",
 ]


@@ -41,22 +40,3 @@ def _transformer(self, minibatch):
             minibatch, (MiniBatch, DGLMiniBatch)
         ), "The transformer output should be an instance of MiniBatch"
         return minibatch
-
-
-@functional_datapipe("to_dgl")
-class DGLMiniBatchConverter(Mapper):
-    """Convert a graphbolt mini-batch to a dgl mini-batch.
-
-    Functional name: :obj:`to_dgl`.
-
-    Parameters
-    ----------
-    datapipe : DataPipe
-        The datapipe.
-    """
-
-    def __init__(
-        self,
-        datapipe,
-    ):
-        super().__init__(datapipe, MiniBatch.to_dgl)
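For context on the deleted pattern: `@functional_datapipe` registers a `Mapper` subclass under a chainable name, which is how `.to_dgl()` used to hang off graphbolt datapipes; with the conversion gone, downstream pipelines simply drop that stage. A self-contained toy version of the same registration pattern (the `double` name and its logic are invented for illustration):

from torch.utils.data import functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper, Mapper


@functional_datapipe("double")
class Doubler(Mapper):
    """Toy stand-in for DGLMiniBatchConverter: maps a function over items."""

    def __init__(self, datapipe):
        super().__init__(datapipe, lambda x: 2 * x)


pipe = IterableWrapper([1, 2, 3]).double()  # chained via the registered name
print(list(pipe))  # [2, 4, 6]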