diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h
index a6177cc79942..4d2ad7f33ba5 100644
--- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h
@@ -50,6 +50,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
  public:
   using NodeTypeToIDMap = torch::Dict<std::string, int64_t>;
   using EdgeTypeToIDMap = torch::Dict<std::string, int64_t>;
+  using NodeAttrMap = torch::Dict<std::string, torch::Tensor>;
   using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;
   /** @brief Default constructor. */
   FusedCSCSamplingGraph() = default;
@@ -66,16 +67,18 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
    * present.
    * @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
    * present.
+   * @param node_attributes A dictionary of node attributes, if present.
    * @param edge_attributes A dictionary of edge attributes, if present.
    *
    */
   FusedCSCSamplingGraph(
       const torch::Tensor& indptr, const torch::Tensor& indices,
-      const torch::optional<torch::Tensor>& node_type_offset,
-      const torch::optional<torch::Tensor>& type_per_edge,
-      const torch::optional<NodeTypeToIDMap>& node_type_to_id,
-      const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
-      const torch::optional<EdgeAttrMap>& edge_attributes);
+      const torch::optional<torch::Tensor>& node_type_offset = torch::nullopt,
+      const torch::optional<torch::Tensor>& type_per_edge = torch::nullopt,
+      const torch::optional<NodeTypeToIDMap>& node_type_to_id = torch::nullopt,
+      const torch::optional<EdgeTypeToIDMap>& edge_type_to_id = torch::nullopt,
+      const torch::optional<NodeAttrMap>& node_attributes = torch::nullopt,
+      const torch::optional<EdgeAttrMap>& edge_attributes = torch::nullopt);
 
   /**
    * @brief Create a fused CSC graph from tensors of CSC format.
@@ -89,6 +92,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
    * present.
    * @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
    * present.
+   * @param node_attributes A dictionary of node attributes, if present.
    * @param edge_attributes A dictionary of edge attributes, if present.
    *
    * @return FusedCSCSamplingGraph
@@ -99,6 +103,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
       const torch::optional<torch::Tensor>& type_per_edge,
       const torch::optional<NodeTypeToIDMap>& node_type_to_id,
       const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
+      const torch::optional<NodeAttrMap>& node_attributes,
       const torch::optional<EdgeAttrMap>& edge_attributes);
 
   /** @brief Get the number of nodes. */
@@ -139,6 +144,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
     return edge_type_to_id_;
   }
 
+  /** @brief Get the node attributes dictionary. */
+  inline const torch::optional<NodeAttrMap> NodeAttributes() const {
+    return node_attributes_;
+  }
+
   /** @brief Get the edge attributes dictionary. */
   inline const torch::optional<EdgeAttrMap> EdgeAttributes() const {
     return edge_attributes_;
@@ -180,6 +190,12 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
     edge_type_to_id_ = edge_type_to_id;
   }
 
+  /** @brief Set the node attributes dictionary. */
+  inline void SetNodeAttributes(
+      const torch::optional<NodeAttrMap>& node_attributes) {
+    node_attributes_ = node_attributes;
+  }
+
   /** @brief Set the edge attributes dictionary. */
   inline void SetEdgeAttributes(
       const torch::optional<EdgeAttrMap>& edge_attributes) {
@@ -367,6 +383,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
    */
   torch::optional<EdgeTypeToIDMap> edge_type_to_id_;
 
+  /**
+   * @brief A dictionary of node attributes. Each key represents the attribute's
+   * name, while the corresponding value holds the attribute's specific value.
+   * The length of each value should match the total number of nodes.
+   */
+  torch::optional<NodeAttrMap> node_attributes_;
+
   /**
    * @brief A dictionary of edge attributes. Each key represents the attribute's
    * name, while the corresponding value holds the attribute's specific value.
diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc
index 8b13f35eaf24..637779aefb0a 100644
--- a/graphbolt/src/fused_csc_sampling_graph.cc
+++ b/graphbolt/src/fused_csc_sampling_graph.cc
@@ -56,6 +56,7 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
     const torch::optional<torch::Tensor>& type_per_edge,
     const torch::optional<NodeTypeToIDMap>& node_type_to_id,
     const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
+    const torch::optional<NodeAttrMap>& node_attributes,
    const torch::optional<EdgeAttrMap>& edge_attributes)
     : indptr_(indptr),
       indices_(indices),
@@ -63,6 +64,7 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
       type_per_edge_(type_per_edge),
       node_type_to_id_(node_type_to_id),
       edge_type_to_id_(edge_type_to_id),
+      node_attributes_(node_attributes),
       edge_attributes_(edge_attributes) {
   TORCH_CHECK(indptr.dim() == 1);
   TORCH_CHECK(indices.dim() == 1);
@@ -75,6 +77,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
     const torch::optional<torch::Tensor>& type_per_edge,
     const torch::optional<NodeTypeToIDMap>& node_type_to_id,
     const torch::optional<EdgeTypeToIDMap>& edge_type_to_id,
+    const torch::optional<NodeAttrMap>& node_attributes,
     const torch::optional<EdgeAttrMap>& edge_attributes) {
   if (node_type_offset.has_value()) {
     auto& offset = node_type_offset.value();
@@ -89,6 +92,11 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
     TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0));
     TORCH_CHECK(edge_type_to_id.has_value());
   }
+  if (node_attributes.has_value()) {
+    for (const auto& pair : node_attributes.value()) {
+      TORCH_CHECK(pair.value().size(0) == indptr.size(0) - 1);
+    }
+  }
   if (edge_attributes.has_value()) {
     for (const auto& pair : edge_attributes.value()) {
       TORCH_CHECK(pair.value().size(0) == indices.size(0));
@@ -96,7 +104,7 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::Create(
   }
   return c10::make_intrusive<FusedCSCSamplingGraph>(
       indptr, indices, node_type_offset, type_per_edge, node_type_to_id,
-      edge_type_to_id, edge_attributes);
+      edge_type_to_id, node_attributes, edge_attributes);
 }
 
 void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
@@ -150,6 +158,25 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
     edge_type_to_id_ = std::move(edge_type_to_id);
   }
 
+  // Optional node attributes.
+  torch::IValue has_node_attributes;
+  if (archive.try_read(
+          "FusedCSCSamplingGraph/has_node_attributes", has_node_attributes) &&
+      has_node_attributes.toBool()) {
+    torch::Dict<torch::IValue, torch::IValue> generic_dict =
+        read_from_archive(archive, "FusedCSCSamplingGraph/node_attributes")
+            .toGenericDict();
+    NodeAttrMap target_dict;
+    for (const auto& pair : generic_dict) {
+      std::string key = pair.key().toStringRef();
+      torch::Tensor value = pair.value().toTensor();
+      // Use move to avoid copy.
+      target_dict.insert(std::move(key), std::move(value));
+    }
+    // Same as above.
+    node_attributes_ = std::move(target_dict);
+  }
+
   // Optional edge attributes.
   torch::IValue has_edge_attributes;
   if (archive.try_read(
@@ -203,6 +230,13 @@ void FusedCSCSamplingGraph::Save(
     archive.write(
         "FusedCSCSamplingGraph/edge_type_to_id", edge_type_to_id_.value());
   }
+  archive.write(
+      "FusedCSCSamplingGraph/has_node_attributes",
+      node_attributes_.has_value());
+  if (node_attributes_) {
+    archive.write(
+        "FusedCSCSamplingGraph/node_attributes", node_attributes_.value());
+  }
   archive.write(
       "FusedCSCSamplingGraph/has_edge_attributes",
       edge_attributes_.has_value());
@@ -238,6 +272,9 @@ void FusedCSCSamplingGraph::SetState(
   if (state.find("edge_type_to_id") != state.end()) {
     edge_type_to_id_ = DetensorizeDict(state.at("edge_type_to_id"));
   }
+  if (state.find("node_attributes") != state.end()) {
+    node_attributes_ = state.at("node_attributes");
+  }
   if (state.find("edge_attributes") != state.end()) {
     edge_attributes_ = state.at("edge_attributes");
   }
@@ -268,6 +305,9 @@ FusedCSCSamplingGraph::GetState() const {
   if (edge_type_to_id_.has_value()) {
     state.insert("edge_type_to_id", TensorizeDict(edge_type_to_id_).value());
   }
+  if (node_attributes_.has_value()) {
+    state.insert("node_attributes", node_attributes_.value());
+  }
   if (edge_attributes_.has_value()) {
     state.insert("edge_attributes", edge_attributes_.value());
   }
@@ -596,10 +636,11 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
   auto type_per_edge = helper.ReadTorchTensor();
   auto node_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());
   auto edge_type_to_id = DetensorizeDict(helper.ReadTorchTensorDict());
+  auto node_attributes = helper.ReadTorchTensorDict();
   auto edge_attributes = helper.ReadTorchTensorDict();
   auto graph = c10::make_intrusive<FusedCSCSamplingGraph>(
       indptr.value(), indices.value(), node_type_offset, type_per_edge,
-      node_type_to_id, edge_type_to_id, edge_attributes);
+      node_type_to_id, edge_type_to_id, node_attributes, edge_attributes);
   auto shared_memory = helper.ReleaseSharedMemory();
   graph->HoldSharedMemoryObject(
       std::move(shared_memory.first), std::move(shared_memory.second));
@@ -616,6 +657,7 @@ FusedCSCSamplingGraph::CopyToSharedMemory(
   helper.WriteTorchTensor(type_per_edge_);
   helper.WriteTorchTensorDict(TensorizeDict(node_type_to_id_));
   helper.WriteTorchTensorDict(TensorizeDict(edge_type_to_id_));
+  helper.WriteTorchTensorDict(node_attributes_);
   helper.WriteTorchTensorDict(edge_attributes_);
   helper.Flush();
   return BuildGraphFromSharedMemoryHelper(std::move(helper));
diff --git a/graphbolt/src/index_select.cc b/graphbolt/src/index_select.cc
index 80068e8185c8..9b6f9fcf61aa 100644
--- a/graphbolt/src/index_select.cc
+++ b/graphbolt/src/index_select.cc
@@ -43,12 +43,7 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
   TORCH_CHECK(
       c10::isIntegralType(indices.scalar_type(), false),
       "IndexSelectCSC is not implemented to slice noninteger types yet.");
-  torch::optional<torch::Tensor> temp;
-  torch::optional<sampling::FusedCSCSamplingGraph::NodeTypeToIDMap> temp2;
-  torch::optional<sampling::FusedCSCSamplingGraph::EdgeTypeToIDMap> temp3;
-  torch::optional<sampling::FusedCSCSamplingGraph::EdgeAttrMap> temp4;
-  sampling::FusedCSCSamplingGraph g(
-      indptr, indices, temp, temp, temp2, temp3, temp4);
+  sampling::FusedCSCSamplingGraph g(indptr, indices);
   const auto res = g.InSubgraph(nodes);
   return std::make_tuple(res->indptr, res->indices);
 }
diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc
index f5a7e503a901..a7dc057af8ee 100644
--- a/graphbolt/src/python_binding.cc
+++ b/graphbolt/src/python_binding.cc
@@ -37,6 +37,7 @@ TORCH_LIBRARY(graphbolt, m) {
       .def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge)
       .def("node_type_to_id", &FusedCSCSamplingGraph::NodeTypeToID)
       .def("edge_type_to_id", &FusedCSCSamplingGraph::EdgeTypeToID)
.def("node_attributes", &FusedCSCSamplingGraph::NodeAttributes) .def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes) .def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr) .def("set_indices", &FusedCSCSamplingGraph::SetIndices) @@ -44,6 +45,7 @@ TORCH_LIBRARY(graphbolt, m) { .def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge) .def("set_node_type_to_id", &FusedCSCSamplingGraph::SetNodeTypeToID) .def("set_edge_type_to_id", &FusedCSCSamplingGraph::SetEdgeTypeToID) + .def("set_node_attributes", &FusedCSCSamplingGraph::SetNodeAttributes) .def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes) .def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph) .def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors) diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index ddb37308e081..7a3d0d9fbb53 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -279,6 +279,27 @@ def edge_type_to_id( """Sets the edge type to id dictionary if present.""" self._c_csc_graph.set_edge_type_to_id(edge_type_to_id) + @property + def node_attributes(self) -> Optional[Dict[str, torch.Tensor]]: + """Returns the node attributes dictionary. + + Returns + ------- + Dict[str, torch.Tensor] or None + If present, returns a dictionary of node attributes. Each key + represents the attribute's name, while the corresponding value + holds the attribute's specific value. The length of each value + should match the total number of nodes." + """ + return self._c_csc_graph.node_attributes() + + @node_attributes.setter + def node_attributes( + self, node_attributes: Optional[Dict[str, torch.Tensor]] + ) -> None: + """Sets the node attributes dictionary.""" + self._c_csc_graph.set_node_attributes(node_attributes) + @property def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]: """Returns the edge attributes dictionary. @@ -892,6 +913,9 @@ def _to(x, device): self.type_per_edge = recursive_apply( self.type_per_edge, lambda x: _to(x, device) ) + self.node_attributes = recursive_apply( + self.node_attributes, lambda x: _to(x, device) + ) self.edge_attributes = recursive_apply( self.edge_attributes, lambda x: _to(x, device) ) @@ -906,6 +930,7 @@ def fused_csc_sampling_graph( type_per_edge: Optional[torch.tensor] = None, node_type_to_id: Optional[Dict[str, int]] = None, edge_type_to_id: Optional[Dict[str, int]] = None, + node_attributes: Optional[Dict[str, torch.tensor]] = None, edge_attributes: Optional[Dict[str, torch.tensor]] = None, ) -> FusedCSCSamplingGraph: """Create a FusedCSCSamplingGraph object from a CSC representation. @@ -926,6 +951,8 @@ def fused_csc_sampling_graph( Map node types to ids, by default None. edge_type_to_id : Optional[Dict[str, int]], optional Map edge types to ids, by default None. + node_attributes: Optional[Dict[str, torch.tensor]], optional + Node attributes of the graph, by default None. edge_attributes: Optional[Dict[str, torch.tensor]], optional Edge attributes of the graph, by default None. @@ -946,7 +973,7 @@ def fused_csc_sampling_graph( ... node_type_offset=node_type_offset, ... type_per_edge=type_per_edge, ... node_type_to_id=ntypes, edge_type_to_id=etypes, - ... edge_attributes=None,) + ... 
+    ...     node_attributes=None, edge_attributes=None,)
     >>> print(graph)
     FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
                           indices=tensor([1, 3, 0, 1, 2, 0, 3]),
@@ -997,6 +1024,7 @@ def fused_csc_sampling_graph(
             type_per_edge,
             node_type_to_id,
             edge_type_to_id,
+            node_attributes,
             edge_attributes,
         ),
     )
@@ -1037,6 +1065,8 @@ def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
         meta_str += f", node_type_to_id={graph.node_type_to_id}"
     if graph.edge_type_to_id is not None:
         meta_str += f", edge_type_to_id={graph.edge_type_to_id}"
+    if graph.node_attributes is not None:
+        meta_str += f", node_attributes={graph.node_attributes}"
     if graph.edge_attributes is not None:
         meta_str += f", edge_attributes={graph.edge_attributes}"
 
@@ -1094,6 +1124,8 @@ def from_dglgraph(
     # Assign edge type according to the order of CSC matrix.
     type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]
 
+    node_attributes = {}
+    edge_attributes = {}
 
     if include_original_edge_id:
         # Assign edge attributes according to the original eids mapping.
@@ -1107,6 +1139,7 @@ def from_dglgraph(
             type_per_edge,
             node_type_to_id,
             edge_type_to_id,
+            node_attributes,
             edge_attributes,
         ),
     )
diff --git a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
index 39e5b491a03d..8e65cad70ff2 100644
--- a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+++ b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
@@ -126,12 +126,19 @@ def test_homo_graph(total_num_nodes, total_num_edges):
     csc_indptr, indices = gbt.random_homo_graph(
         total_num_nodes, total_num_edges
     )
+    node_attributes = {
+        "A1": torch.arange(total_num_nodes),
+        "A2": torch.arange(total_num_nodes),
+    }
     edge_attributes = {
         "A1": torch.randn(total_num_edges),
         "A2": torch.randn(total_num_edges),
     }
     graph = gb.fused_csc_sampling_graph(
-        csc_indptr, indices, edge_attributes=edge_attributes
+        csc_indptr,
+        indices,
+        node_attributes=node_attributes,
+        edge_attributes=edge_attributes,
     )
 
     assert graph.total_num_nodes == total_num_nodes
@@ -140,6 +147,7 @@ def test_homo_graph(total_num_nodes, total_num_edges):
 
     assert torch.equal(csc_indptr, graph.csc_indptr)
     assert torch.equal(indices, graph.indices)
+    assert graph.node_attributes == node_attributes
     assert graph.edge_attributes == edge_attributes
     assert graph.node_type_offset is None
     assert graph.type_per_edge is None
@@ -167,6 +175,10 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
     ) = gbt.random_hetero_graph(
         total_num_nodes, total_num_edges, num_ntypes, num_etypes
     )
+    node_attributes = {
+        "A1": torch.arange(total_num_nodes),
+        "A2": torch.arange(total_num_nodes),
+    }
     edge_attributes = {
         "A1": torch.randn(total_num_edges),
         "A2": torch.randn(total_num_edges),
@@ -178,6 +190,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
         type_per_edge=type_per_edge,
         node_type_to_id=node_type_to_id,
         edge_type_to_id=edge_type_to_id,
+        node_attributes=node_attributes,
         edge_attributes=edge_attributes,
     )
 
@@ -188,6 +201,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
     assert torch.equal(indices, graph.indices)
     assert torch.equal(node_type_offset, graph.node_type_offset)
     assert torch.equal(type_per_edge, graph.type_per_edge)
+    assert graph.node_attributes == node_attributes
     assert graph.edge_attributes == edge_attributes
     assert node_type_to_id == graph.node_type_to_id
     assert edge_type_to_id == graph.edge_type_to_id
@@ -327,11 +341,32 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
     "total_num_nodes, total_num_edges",
     [(1, 1), (100, 1), (10, 50), (1000, 50000)],
 )
-def test_load_save_homo_graph(total_num_nodes, total_num_edges):
+@pytest.mark.parametrize("has_node_attrs", [True, False])
+@pytest.mark.parametrize("has_edge_attrs", [True, False])
+def test_load_save_homo_graph(
+    total_num_nodes, total_num_edges, has_node_attrs, has_edge_attrs
+):
     csc_indptr, indices = gbt.random_homo_graph(
         total_num_nodes, total_num_edges
     )
-    graph = gb.fused_csc_sampling_graph(csc_indptr, indices)
+    node_attributes = None
+    if has_node_attrs:
+        node_attributes = {
+            "A": torch.arange(total_num_nodes),
+            "B": torch.arange(total_num_nodes),
+        }
+    edge_attributes = None
+    if has_edge_attrs:
+        edge_attributes = {
+            "A": torch.arange(total_num_edges),
+            "B": torch.arange(total_num_edges),
+        }
+    graph = gb.fused_csc_sampling_graph(
+        csc_indptr,
+        indices,
+        node_attributes=node_attributes,
+        edge_attributes=edge_attributes,
+    )
 
     with tempfile.TemporaryDirectory() as test_dir:
         filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
@@ -348,7 +383,22 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
     assert graph.type_per_edge is None and graph2.type_per_edge is None
     assert graph.node_type_to_id is None and graph2.node_type_to_id is None
     assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
-    assert graph.edge_attributes is None and graph2.edge_attributes is None
+    if has_node_attrs:
+        assert graph.node_attributes.keys() == graph2.node_attributes.keys()
+        for key in graph.node_attributes.keys():
+            assert torch.equal(
+                graph.node_attributes[key], graph2.node_attributes[key]
+            )
+    else:
+        assert graph.node_attributes is None and graph2.node_attributes is None
+    if has_edge_attrs:
+        assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
+        for key in graph.edge_attributes.keys():
+            assert torch.equal(
+                graph.edge_attributes[key], graph2.edge_attributes[key]
+            )
+    else:
+        assert graph.edge_attributes is None and graph2.edge_attributes is None
 
 
 @unittest.skipIf(
@@ -360,8 +410,15 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
     [(1, 1), (100, 1), (10, 50), (1000, 50000)],
 )
 @pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)])
+@pytest.mark.parametrize("has_node_attrs", [True, False])
+@pytest.mark.parametrize("has_edge_attrs", [True, False])
 def test_load_save_hetero_graph(
-    total_num_nodes, total_num_edges, num_ntypes, num_etypes
+    total_num_nodes,
+    total_num_edges,
+    num_ntypes,
+    num_etypes,
+    has_node_attrs,
+    has_edge_attrs,
 ):
     (
         csc_indptr,
@@ -373,6 +430,18 @@ def test_load_save_hetero_graph(
     ) = gbt.random_hetero_graph(
         total_num_nodes, total_num_edges, num_ntypes, num_etypes
     )
+    node_attributes = None
+    if has_node_attrs:
+        node_attributes = {
+            "A": torch.arange(total_num_nodes),
+            "B": torch.arange(total_num_nodes),
+        }
+    edge_attributes = None
+    if has_edge_attrs:
+        edge_attributes = {
+            "A": torch.arange(total_num_edges),
+            "B": torch.arange(total_num_edges),
+        }
     graph = gb.fused_csc_sampling_graph(
         csc_indptr,
         indices,
@@ -380,6 +449,8 @@ def test_load_save_hetero_graph(
         type_per_edge=type_per_edge,
         node_type_to_id=node_type_to_id,
         edge_type_to_id=edge_type_to_id,
+        node_attributes=node_attributes,
+        edge_attributes=edge_attributes,
     )
 
     with tempfile.TemporaryDirectory() as test_dir:
@@ -396,6 +467,22 @@
     assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
     assert graph.node_type_to_id == graph2.node_type_to_id
     assert graph.edge_type_to_id == graph2.edge_type_to_id
+    if has_node_attrs:
+        assert graph.node_attributes.keys() == graph2.node_attributes.keys()
+        for key in graph.node_attributes.keys():
+            assert torch.equal(
+                graph.node_attributes[key], graph2.node_attributes[key]
+            )
+    else:
+        assert graph.node_attributes is None and graph2.node_attributes is None
+    if has_edge_attrs:
+        assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
+        for key in graph.edge_attributes.keys():
+            assert torch.equal(
+                graph.edge_attributes[key], graph2.edge_attributes[key]
+            )
+    else:
+        assert graph.edge_attributes is None and graph2.edge_attributes is None
 
 
 @unittest.skipIf(
@@ -406,11 +493,32 @@ def test_load_save_hetero_graph(
     "total_num_nodes, total_num_edges",
     [(1, 1), (100, 1), (10, 50), (1000, 50000)],
 )
-def test_pickle_homo_graph(total_num_nodes, total_num_edges):
+@pytest.mark.parametrize("has_node_attrs", [True, False])
+@pytest.mark.parametrize("has_edge_attrs", [True, False])
+def test_pickle_homo_graph(
+    total_num_nodes, total_num_edges, has_node_attrs, has_edge_attrs
+):
     csc_indptr, indices = gbt.random_homo_graph(
         total_num_nodes, total_num_edges
     )
-    graph = gb.fused_csc_sampling_graph(csc_indptr, indices)
+    node_attributes = None
+    if has_node_attrs:
+        node_attributes = {
+            "A": torch.arange(total_num_nodes),
+            "B": torch.arange(total_num_nodes),
+        }
+    edge_attributes = None
+    if has_edge_attrs:
+        edge_attributes = {
+            "A": torch.arange(total_num_edges),
+            "B": torch.arange(total_num_edges),
+        }
+    graph = gb.fused_csc_sampling_graph(
+        csc_indptr,
+        indices,
+        node_attributes=node_attributes,
+        edge_attributes=edge_attributes,
+    )
 
     serialized = pickle.dumps(graph)
     graph2 = pickle.loads(serialized)
@@ -425,7 +533,22 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
     assert graph.type_per_edge is None and graph2.type_per_edge is None
     assert graph.node_type_to_id is None and graph2.node_type_to_id is None
     assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
-    assert graph.edge_attributes is None and graph2.edge_attributes is None
+    if has_node_attrs:
+        assert graph.node_attributes.keys() == graph2.node_attributes.keys()
+        for key in graph.node_attributes.keys():
+            assert torch.equal(
+                graph.node_attributes[key], graph2.node_attributes[key]
+            )
+    else:
+        assert graph.node_attributes is None and graph2.node_attributes is None
+    if has_edge_attrs:
+        assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
+        for key in graph.edge_attributes.keys():
+            assert torch.equal(
+                graph.edge_attributes[key], graph2.edge_attributes[key]
+            )
+    else:
+        assert graph.edge_attributes is None and graph2.edge_attributes is None
 
 
 @unittest.skipIf(
@@ -437,8 +560,15 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
     [(1, 1), (100, 1), (10, 50), (1000, 50000)],
 )
 @pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)])
+@pytest.mark.parametrize("has_node_attrs", [True, False])
+@pytest.mark.parametrize("has_edge_attrs", [True, False])
 def test_pickle_hetero_graph(
-    total_num_nodes, total_num_edges, num_ntypes, num_etypes
+    total_num_nodes,
+    total_num_edges,
+    num_ntypes,
+    num_etypes,
+    has_node_attrs,
+    has_edge_attrs,
 ):
     (
         csc_indptr,
@@ -450,10 +580,18 @@ def test_pickle_hetero_graph(
     ) = gbt.random_hetero_graph(
         total_num_nodes, total_num_edges, num_ntypes, num_etypes
     )
-    edge_attributes = {
-        "a": torch.randn((total_num_edges,)),
-        "b": torch.randint(1, 10, (total_num_edges,)),
-    }
+    node_attributes = None
+    if has_node_attrs:
+        node_attributes = {
+            "A": torch.arange(total_num_nodes),
+            "B": torch.arange(total_num_nodes),
+        }
+    edge_attributes = None
+    if has_edge_attrs:
+        edge_attributes = {
+            "A": torch.arange(total_num_edges),
+            "B": torch.arange(total_num_edges),
+        }
     graph = gb.fused_csc_sampling_graph(
         csc_indptr,
         indices,
@@ -461,6 +599,7 @@ def test_pickle_hetero_graph(
         type_per_edge=type_per_edge,
         node_type_to_id=node_type_to_id,
         edge_type_to_id=edge_type_to_id,
+        node_attributes=node_attributes,
         edge_attributes=edge_attributes,
     )
 
@@ -480,9 +619,22 @@
     assert graph.edge_type_to_id.keys() == graph2.edge_type_to_id.keys()
     for i in graph.edge_type_to_id.keys():
         assert graph.edge_type_to_id[i] == graph2.edge_type_to_id[i]
-    assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
-    for i in graph.edge_attributes.keys():
-        assert torch.equal(graph.edge_attributes[i], graph2.edge_attributes[i])
+    if has_node_attrs:
+        assert graph.node_attributes.keys() == graph2.node_attributes.keys()
+        for key in graph.node_attributes.keys():
+            assert torch.equal(
+                graph.node_attributes[key], graph2.node_attributes[key]
+            )
+    else:
+        assert graph.node_attributes is None and graph2.node_attributes is None
+    if has_edge_attrs:
+        assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
+        for key in graph.edge_attributes.keys():
+            assert torch.equal(
+                graph.edge_attributes[key], graph2.edge_attributes[key]
+            )
+    else:
+        assert graph.edge_attributes is None and graph2.edge_attributes is None
 
 
 def process_csc_sampling_graph_multiprocessing(graph):
@@ -1258,6 +1410,18 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
     t1[:] = old_t1
 
 
+def check_node_edge_attributes(graph1, graph2, attributes, attr_name):
+    for name, attr in attributes.items():
+        edge_attributes_1 = getattr(graph1, attr_name)
+        edge_attributes_2 = getattr(graph2, attr_name)
+        assert name in edge_attributes_1
+        assert name in edge_attributes_2
+        assert torch.equal(edge_attributes_1[name], attr)
+        check_tensors_on_the_same_shared_memory(
+            edge_attributes_1[name], edge_attributes_2[name]
+        )
+
+
 @unittest.skipIf(
     F._default_context_str == "gpu",
     reason="FusedCSCSamplingGraph is only supported on CPU.",
@@ -1266,22 +1430,31 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
     "total_num_nodes, total_num_edges",
     [(1, 1), (100, 1), (10, 50), (1000, 50000)],
 )
+@pytest.mark.parametrize("test_node_attrs", [True, False])
 @pytest.mark.parametrize("test_edge_attrs", [True, False])
 def test_homo_graph_on_shared_memory(
-    total_num_nodes, total_num_edges, test_edge_attrs
+    total_num_nodes, total_num_edges, test_node_attrs, test_edge_attrs
 ):
     csc_indptr, indices = gbt.random_homo_graph(
         total_num_nodes, total_num_edges
     )
+    node_attributes = None
+    if test_node_attrs:
+        node_attributes = {
+            "A1": torch.arange(total_num_nodes),
+            "A2": torch.arange(total_num_nodes),
+        }
+    edge_attributes = None
     if test_edge_attrs:
         edge_attributes = {
             "A1": torch.randn(total_num_edges),
             "A2": torch.randn(total_num_edges),
         }
-    else:
-        edge_attributes = None
     graph = gb.fused_csc_sampling_graph(
-        csc_indptr, indices, edge_attributes=edge_attributes
+        csc_indptr,
+        indices,
+        node_attributes=node_attributes,
+        edge_attributes=edge_attributes,
    )
 
     shm_name = "test_homo_g"
@@ -1307,14 +1480,14 @@ def test_homo_graph_on_shared_memory(
     )
     check_tensors_on_the_same_shared_memory(graph1.indices, graph2.indices)
 
+    if test_node_attrs:
+        check_node_edge_attributes(
+            graph1, graph2, node_attributes, "node_attributes"
+        )
     if test_edge_attrs:
-        for name, edge_attr in edge_attributes.items():
-            assert name in graph1.edge_attributes
-            assert name in graph2.edge_attributes
-            assert torch.equal(graph1.edge_attributes[name], edge_attr)
-            check_tensors_on_the_same_shared_memory(
-                graph1.edge_attributes[name], graph2.edge_attributes[name]
-            )
+        check_node_edge_attributes(
+            graph1, graph2, edge_attributes, "edge_attributes"
+        )
 
     assert graph1.node_type_offset is None and graph2.node_type_offset is None
     assert graph1.type_per_edge is None and graph2.type_per_edge is None
@@ -1333,9 +1506,15 @@ def test_homo_graph_on_shared_memory(
 @pytest.mark.parametrize(
     "num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1), (1000, 1000)]
 )
+@pytest.mark.parametrize("test_node_attrs", [True, False])
 @pytest.mark.parametrize("test_edge_attrs", [True, False])
 def test_hetero_graph_on_shared_memory(
-    total_num_nodes, total_num_edges, num_ntypes, num_etypes, test_edge_attrs
+    total_num_nodes,
+    total_num_edges,
+    num_ntypes,
+    num_etypes,
+    test_node_attrs,
+    test_edge_attrs,
 ):
     (
         csc_indptr,
@@ -1348,13 +1527,20 @@ def test_hetero_graph_on_shared_memory(
         total_num_nodes, total_num_edges, num_ntypes, num_etypes
     )
 
+    node_attributes = None
+    if test_node_attrs:
+        node_attributes = {
+            "A1": torch.arange(total_num_nodes),
+            "A2": torch.arange(total_num_nodes),
+        }
+
+    edge_attributes = None
     if test_edge_attrs:
         edge_attributes = {
             "A1": torch.randn(total_num_edges),
             "A2": torch.randn(total_num_edges),
         }
-    else:
-        edge_attributes = None
+
     graph = gb.fused_csc_sampling_graph(
         csc_indptr,
         indices,
@@ -1362,6 +1548,7 @@ def test_hetero_graph_on_shared_memory(
         type_per_edge=type_per_edge,
         node_type_to_id=node_type_to_id,
         edge_type_to_id=edge_type_to_id,
+        node_attributes=node_attributes,
         edge_attributes=edge_attributes,
     )
 
@@ -1398,14 +1585,14 @@ def test_hetero_graph_on_shared_memory(
         graph1.type_per_edge, graph2.type_per_edge
     )
 
+    if test_node_attrs:
+        check_node_edge_attributes(
+            graph1, graph2, node_attributes, "node_attributes"
+        )
     if test_edge_attrs:
-        for name, edge_attr in edge_attributes.items():
-            assert name in graph1.edge_attributes
-            assert name in graph2.edge_attributes
-            assert torch.equal(graph1.edge_attributes[name], edge_attr)
-            check_tensors_on_the_same_shared_memory(
-                graph1.edge_attributes[name], graph2.edge_attributes[name]
-            )
+        check_node_edge_attributes(
+            graph1, graph2, edge_attributes, "edge_attributes"
+        )
 
     assert node_type_to_id == graph1.node_type_to_id
     assert edge_type_to_id == graph1.edge_type_to_id
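
Usage sketch (not part of the patch): a minimal, untested example of how the node_attributes argument introduced above could be exercised from Python. The attribute names "feat" and "label" and the tensor values are illustrative; the only constraint enforced by FusedCSCSamplingGraph::Create is that each attribute tensor's first dimension equals the number of nodes, i.e. len(csc_indptr) - 1.

import torch
import dgl.graphbolt as gb

# A tiny 3-node, 6-edge homogeneous graph in CSC form.
csc_indptr = torch.tensor([0, 2, 4, 6])
indices = torch.tensor([1, 2, 0, 2, 0, 1])

# Each value's first dimension must equal the number of nodes (3 here).
node_attributes = {"feat": torch.randn(3, 4), "label": torch.arange(3)}
edge_attributes = {"weight": torch.rand(6)}

graph = gb.fused_csc_sampling_graph(
    csc_indptr,
    indices,
    node_attributes=node_attributes,
    edge_attributes=edge_attributes,
)

# The new property returns the dictionary that was passed in.
assert torch.equal(graph.node_attributes["label"], torch.arange(3))

# The setter added by this change replaces the whole dictionary.
graph.node_attributes = {"feat": torch.zeros(3, 4)}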