Skip to content

Commit

Permalink
[GraphBolt] Update __repr__ of FusedCSCSamplingGraph (#6956)
Browse files Browse the repository at this point in the history
  • Loading branch information
Skeleton003 authored Jan 26, 2024
1 parent 2da6ace commit fe78093
Showing 1 changed file with 50 additions and 54 deletions.
104 changes: 50 additions & 54 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""CSC format sampling graph."""

import textwrap

# pylint: disable= invalid-name
from typing import Dict, Optional, Union

Expand Down Expand Up @@ -26,7 +29,38 @@ class FusedCSCSamplingGraph(SamplingGraph):
r"""A sampling graph in CSC format."""

def __repr__(self):
return _csc_sampling_graph_str(self)
final_str = (
"{classname}(csc_indptr={csc_indptr},\n"
"indices={indices},\n"
"{metadata})"
)

classname_str = self.__class__.__name__
csc_indptr_str = str(self.csc_indptr)
indices_str = str(self.indices)
meta_str = f"total_num_nodes={self.total_num_nodes}, num_edges={self.num_edges},"
if self.node_type_offset is not None:
meta_str += f"\nnode_type_offset={self.node_type_offset},"
if self.type_per_edge is not None:
meta_str += f"\ntype_per_edge={self.type_per_edge},"
if self.node_type_to_id is not None:
meta_str += f"\nnode_type_to_id={self.node_type_to_id},"
if self.edge_type_to_id is not None:
meta_str += f"\nedge_type_to_id={self.edge_type_to_id},"
if self.node_attributes is not None:
meta_str += f"\nnode_attributes={self.node_attributes},"
if self.edge_attributes is not None:
meta_str += f"\nedge_attributes={self.edge_attributes},"

final_str = final_str.format(
classname=classname_str,
csc_indptr=csc_indptr_str,
indices=indices_str,
metadata=meta_str,
)
return textwrap.indent(
final_str, " " * (len(classname_str) + 1)
).strip()

def __init__(
self,
Expand Down Expand Up @@ -1120,19 +1154,23 @@ def fused_csc_sampling_graph(
--------
>>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
>>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
>>> csc_indptr = torch.tensor([0, 2, 5, 7])
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
>>> node_type_offset = torch.tensor([0, 1, 2, 3])
>>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
>>> csc_indptr = torch.tensor([0, 2, 5, 7, 8])
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3, 2])
>>> node_type_offset = torch.tensor([0, 1, 2, 4])
>>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0, 0])
>>> graph = graphbolt.fused_csc_sampling_graph(csc_indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, edge_type_to_id=etypes,
... node_attributes=None, edge_attributes=None,)
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, edge_type_to_id=etypes,
... node_attributes=None, edge_attributes=None,)
>>> print(graph)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
total_num_nodes=3, total_num_edges=7)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7, 8]),
indices=tensor([1, 3, 0, 1, 2, 0, 3, 2]),
total_num_nodes=4, num_edges={'n1:e1:n2': 5, 'n1:e2:n3': 3},
node_type_offset=tensor([0, 1, 2, 4]),
type_per_edge=tensor([0, 1, 0, 1, 1, 0, 0, 0]),
node_type_to_id={'n1': 0, 'n2': 1, 'n3': 2},
edge_type_to_id={'n1:e1:n2': 0, 'n1:e2:n3': 1},)
"""
if node_type_to_id is not None and edge_type_to_id is not None:
node_types = list(node_type_to_id.keys())
Expand Down Expand Up @@ -1205,48 +1243,6 @@ def load_from_shared_memory(
)


def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
"""Internal function for converting a csc sampling graph to string
representation.
"""
csc_indptr_str = str(graph.csc_indptr)
indices_str = str(graph.indices)
meta_str = f"num_nodes={graph.total_num_nodes}, num_edges={graph.num_edges}"
if graph.node_type_offset is not None:
meta_str += f", node_type_offset={graph.node_type_offset}"
if graph.type_per_edge is not None:
meta_str += f", type_per_edge={graph.type_per_edge}"
if graph.node_type_to_id is not None:
meta_str += f", node_type_to_id={graph.node_type_to_id}"
if graph.edge_type_to_id is not None:
meta_str += f", edge_type_to_id={graph.edge_type_to_id}"
if graph.node_attributes is not None:
meta_str += f", node_attributes={graph.node_attributes}"
if graph.edge_attributes is not None:
meta_str += f", edge_attributes={graph.edge_attributes}"

prefix = f"{type(graph).__name__}("

def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)

final_str = (
"csc_indptr="
+ _add_indent(csc_indptr_str, len("csc_indptr="))
+ ",\n"
+ "indices="
+ _add_indent(indices_str, len("indices="))
+ ",\n"
+ meta_str
+ ")"
)

final_str = prefix + _add_indent(final_str, len(prefix))
return final_str


def from_dglgraph(
g: DGLGraph,
is_homogeneous: bool = False,
Expand Down

0 comments on commit fe78093

Please sign in to comment.