From cc77561c5f5c3a7eb600eafdcc9b119f64d55603 Mon Sep 17 00:00:00 2001
From: Anthony Mahanna <anthony.mahanna@arangodb.com>
Date: Mon, 13 May 2024 12:37:44 -0400
Subject: [PATCH] MLP-642 | return col_to_adb_id_to_ind from PygLoader methods

---
 python/phenolrs/numpy_loader.py |  9 +++++++--
 python/phenolrs/pyg_loader.py   | 12 ++++++------
 python/tests/test_all.py        | 22 ++++++++++++++--------
 3 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/python/phenolrs/numpy_loader.py b/python/phenolrs/numpy_loader.py
index 0e089ff..e05ad6c 100644
--- a/python/phenolrs/numpy_loader.py
+++ b/python/phenolrs/numpy_loader.py
@@ -83,7 +83,7 @@ def load_graph_to_numpy(
                 for e_col_name, entries in metagraph["edgeCollections"].items()
             ]
 
-        features_by_col, coo_map, col_to_key_inds = graph_to_pyg_format(
+        features_by_col, coo_map, col_to_adb_id_to_ind = graph_to_pyg_format(
             {
                 "database": database,
                 "vertex_collections": vertex_collections,
@@ -92,4 +92,9 @@ def load_graph_to_numpy(
             }
         )
 
-        return features_by_col, coo_map, col_to_key_inds, vertex_cols_source_to_output
+        return (
+            features_by_col,
+            coo_map,
+            col_to_adb_id_to_ind,
+            vertex_cols_source_to_output,
+        )
diff --git a/python/phenolrs/pyg_loader.py b/python/phenolrs/pyg_loader.py
index de03774..3e81d53 100644
--- a/python/phenolrs/pyg_loader.py
+++ b/python/phenolrs/pyg_loader.py
@@ -20,7 +20,7 @@ def load_into_pyg_data(
         tls_cert: typing.Any | None = None,
         parallelism: int | None = None,
         batch_size: int | None = None,
-    ) -> Data:
+    ) -> tuple[Data, dict[str, dict[str, int]]]:
         if "vertexCollections" not in metagraph:
             raise PhenolError("vertexCollections not found in metagraph")
         if "edgeCollections" not in metagraph:
@@ -46,7 +46,7 @@ def load_into_pyg_data(
         (
             features_by_col,
             coo_map,
-            col_to_key_inds,
+            col_to_adb_id_to_ind,
             vertex_cols_source_to_output,
         ) = NumpyLoader.load_graph_to_numpy(
             database,
@@ -85,7 +85,7 @@ def load_into_pyg_data(
                 if result.numel() > 0:
                     data["edge_index"] = result
 
-        return data
+        return data, col_to_adb_id_to_ind
 
     @staticmethod
     def load_into_pyg_heterodata(
@@ -98,7 +98,7 @@ def load_into_pyg_heterodata(
         tls_cert: typing.Any | None = None,
         parallelism: int | None = None,
         batch_size: int | None = None,
-    ) -> HeteroData:
+    ) -> tuple[HeteroData, dict[str, dict[str, int]]]:
         if "vertexCollections" not in metagraph:
             raise PhenolError("vertexCollections not found in metagraph")
         if "edgeCollections" not in metagraph:
@@ -112,7 +112,7 @@ def load_into_pyg_heterodata(
         (
             features_by_col,
             coo_map,
-            col_to_key_inds,
+            col_to_adb_id_to_ind,
             vertex_cols_source_to_output,
         ) = NumpyLoader.load_graph_to_numpy(
             database,
@@ -142,4 +142,4 @@ def load_into_pyg_heterodata(
             if result.numel() > 0:
                 data[(from_name, edge_col_name, to_name)].edge_index = result
 
-        return data
+        return data, col_to_adb_id_to_ind
diff --git a/python/tests/test_all.py b/python/tests/test_all.py
index 7852b60..5401605 100644
--- a/python/tests/test_all.py
+++ b/python/tests/test_all.py
@@ -19,9 +19,12 @@ def test_phenol_abide_hetero(
         username=connection_information["username"],
         password=connection_information["password"],
     )
-    assert isinstance(result, HeteroData)
-    assert result["Subjects"]["x"].shape == (871, 2000)
+    data, col_to_adb_id_to_ind = result
+    assert isinstance(data, HeteroData)
+    assert data["Subjects"]["x"].shape == (871, 2000)
+    assert len(col_to_adb_id_to_ind["Subjects"]) == 871
 
+    # Metagraph variation
     result = PygLoader.load_into_pyg_heterodata(
         connection_information["dbName"],
         {
@@ -34,8 +37,11 @@ def test_phenol_abide_hetero(
         username=connection_information["username"],
         password=connection_information["password"],
     )
-    assert isinstance(result, HeteroData)
-    assert result["Subjects"]["x"].shape == (871, 2000)
+
+    data, col_to_adb_id_to_ind = result
+    assert isinstance(data, HeteroData)
+    assert data["Subjects"]["x"].shape == (871, 2000)
+    assert len(col_to_adb_id_to_ind["Subjects"]) == 871
 
 
 def test_phenol_abide_numpy(
@@ -44,7 +50,7 @@ def test_phenol_abide_numpy(
     (
         features_by_col,
         coo_map,
-        col_to_key_inds,
+        col_to_adb_id_to_ind,
         vertex_cols_source_to_output,
     ) = NumpyLoader.load_graph_to_numpy(
         connection_information["dbName"],
@@ -62,13 +68,13 @@ def test_phenol_abide_numpy(
         2,
         606770,
     )
-    assert len(col_to_key_inds["Subjects"]) == 871
+    assert len(col_to_adb_id_to_ind["Subjects"]) == 871
     assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}}
 
     (
         features_by_col,
         coo_map,
-        col_to_key_inds,
+        col_to_adb_id_to_ind,
         vertex_cols_source_to_output,
     ) = NumpyLoader.load_graph_to_numpy(
         connection_information["dbName"],
@@ -83,5 +89,5 @@ def test_phenol_abide_numpy(
 
     assert features_by_col["Subjects"]["brain_fmri_features"].shape == (871, 2000)
     assert len(coo_map) == 0
-    assert len(col_to_key_inds["Subjects"]) == 871
+    assert len(col_to_adb_id_to_ind["Subjects"]) == 871
     assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}}