diff --git a/python/phenolrs/numpy_loader.py b/python/phenolrs/numpy_loader.py index 0e089ff..e05ad6c 100644 --- a/python/phenolrs/numpy_loader.py +++ b/python/phenolrs/numpy_loader.py @@ -83,7 +83,7 @@ def load_graph_to_numpy( for e_col_name, entries in metagraph["edgeCollections"].items() ] - features_by_col, coo_map, col_to_key_inds = graph_to_pyg_format( + features_by_col, coo_map, col_to_adb_id_to_ind = graph_to_pyg_format( { "database": database, "vertex_collections": vertex_collections, @@ -92,4 +92,9 @@ def load_graph_to_numpy( } ) - return features_by_col, coo_map, col_to_key_inds, vertex_cols_source_to_output + return ( + features_by_col, + coo_map, + col_to_adb_id_to_ind, + vertex_cols_source_to_output, + ) diff --git a/python/phenolrs/pyg_loader.py b/python/phenolrs/pyg_loader.py index de03774..3e81d53 100644 --- a/python/phenolrs/pyg_loader.py +++ b/python/phenolrs/pyg_loader.py @@ -20,7 +20,7 @@ def load_into_pyg_data( tls_cert: typing.Any | None = None, parallelism: int | None = None, batch_size: int | None = None, - ) -> Data: + ) -> tuple[Data, dict[str, dict[str, int]]]: if "vertexCollections" not in metagraph: raise PhenolError("vertexCollections not found in metagraph") if "edgeCollections" not in metagraph: @@ -46,7 +46,7 @@ def load_into_pyg_data( ( features_by_col, coo_map, - col_to_key_inds, + col_to_adb_id_to_ind, vertex_cols_source_to_output, ) = NumpyLoader.load_graph_to_numpy( database, @@ -85,7 +85,7 @@ def load_into_pyg_data( if result.numel() > 0: data["edge_index"] = result - return data + return data, col_to_adb_id_to_ind @staticmethod def load_into_pyg_heterodata( @@ -98,7 +98,7 @@ def load_into_pyg_heterodata( tls_cert: typing.Any | None = None, parallelism: int | None = None, batch_size: int | None = None, - ) -> HeteroData: + ) -> tuple[HeteroData, dict[str, dict[str, int]]]: if "vertexCollections" not in metagraph: raise PhenolError("vertexCollections not found in metagraph") if "edgeCollections" not in metagraph: @@ -112,7 +112,7 @@ def load_into_pyg_heterodata( ( features_by_col, coo_map, - col_to_key_inds, + col_to_adb_id_to_ind, vertex_cols_source_to_output, ) = NumpyLoader.load_graph_to_numpy( database, @@ -142,4 +142,4 @@ def load_into_pyg_heterodata( if result.numel() > 0: data[(from_name, edge_col_name, to_name)].edge_index = result - return data + return data, col_to_adb_id_to_ind diff --git a/python/tests/test_all.py b/python/tests/test_all.py index 7852b60..5401605 100644 --- a/python/tests/test_all.py +++ b/python/tests/test_all.py @@ -19,9 +19,12 @@ def test_phenol_abide_hetero( username=connection_information["username"], password=connection_information["password"], ) - assert isinstance(result, HeteroData) - assert result["Subjects"]["x"].shape == (871, 2000) + data, col_to_adb_id_to_ind = result + assert isinstance(data, HeteroData) + assert data["Subjects"]["x"].shape == (871, 2000) + assert len(col_to_adb_id_to_ind["Subjects"]) == 871 + # Metagraph variation result = PygLoader.load_into_pyg_heterodata( connection_information["dbName"], { @@ -34,8 +37,11 @@ def test_phenol_abide_hetero( username=connection_information["username"], password=connection_information["password"], ) - assert isinstance(result, HeteroData) - assert result["Subjects"]["x"].shape == (871, 2000) + + data, col_to_adb_id_to_ind = result + assert isinstance(data, HeteroData) + assert data["Subjects"]["x"].shape == (871, 2000) + assert len(col_to_adb_id_to_ind["Subjects"]) == 871 def test_phenol_abide_numpy( @@ -44,7 +50,7 @@ def test_phenol_abide_numpy( ( features_by_col, coo_map, - col_to_key_inds, + col_to_adb_id_to_ind, vertex_cols_source_to_output, ) = NumpyLoader.load_graph_to_numpy( connection_information["dbName"], @@ -62,13 +68,13 @@ def test_phenol_abide_numpy( 2, 606770, ) - assert len(col_to_key_inds["Subjects"]) == 871 + assert len(col_to_adb_id_to_ind["Subjects"]) == 871 assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}} ( features_by_col, coo_map, - col_to_key_inds, + col_to_adb_id_to_ind, vertex_cols_source_to_output, ) = NumpyLoader.load_graph_to_numpy( connection_information["dbName"], @@ -83,5 +89,5 @@ def test_phenol_abide_numpy( assert features_by_col["Subjects"]["brain_fmri_features"].shape == (871, 2000) assert len(coo_map) == 0 - assert len(col_to_key_inds["Subjects"]) == 871 + assert len(col_to_adb_id_to_ind["Subjects"]) == 871 assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}}