Skip to content

Commit

Permalink
Merge pull request #13 from arangoml/MLP-642
Browse files Browse the repository at this point in the history
MLP-642 | return index mapping
  • Loading branch information
Alex Geenen authored May 14, 2024
2 parents c9e6189 + cc77561 commit 6c5f11c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
9 changes: 7 additions & 2 deletions python/phenolrs/numpy_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def load_graph_to_numpy(
for e_col_name, entries in metagraph["edgeCollections"].items()
]

features_by_col, coo_map, col_to_key_inds = graph_to_pyg_format(
features_by_col, coo_map, col_to_adb_id_to_ind = graph_to_pyg_format(
{
"database": database,
"vertex_collections": vertex_collections,
Expand All @@ -92,4 +92,9 @@ def load_graph_to_numpy(
}
)

return features_by_col, coo_map, col_to_key_inds, vertex_cols_source_to_output
return (
features_by_col,
coo_map,
col_to_adb_id_to_ind,
vertex_cols_source_to_output,
)
12 changes: 6 additions & 6 deletions python/phenolrs/pyg_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def load_into_pyg_data(
tls_cert: typing.Any | None = None,
parallelism: int | None = None,
batch_size: int | None = None,
) -> Data:
) -> tuple[Data, dict[str, dict[str, int]]]:
if "vertexCollections" not in metagraph:
raise PhenolError("vertexCollections not found in metagraph")
if "edgeCollections" not in metagraph:
Expand All @@ -46,7 +46,7 @@ def load_into_pyg_data(
(
features_by_col,
coo_map,
col_to_key_inds,
col_to_adb_id_to_ind,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
database,
Expand Down Expand Up @@ -85,7 +85,7 @@ def load_into_pyg_data(
if result.numel() > 0:
data["edge_index"] = result

return data
return data, col_to_adb_id_to_ind

@staticmethod
def load_into_pyg_heterodata(
Expand All @@ -98,7 +98,7 @@ def load_into_pyg_heterodata(
tls_cert: typing.Any | None = None,
parallelism: int | None = None,
batch_size: int | None = None,
) -> HeteroData:
) -> tuple[HeteroData, dict[str, dict[str, int]]]:
if "vertexCollections" not in metagraph:
raise PhenolError("vertexCollections not found in metagraph")
if "edgeCollections" not in metagraph:
Expand All @@ -112,7 +112,7 @@ def load_into_pyg_heterodata(
(
features_by_col,
coo_map,
col_to_key_inds,
col_to_adb_id_to_ind,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
database,
Expand Down Expand Up @@ -142,4 +142,4 @@ def load_into_pyg_heterodata(
if result.numel() > 0:
data[(from_name, edge_col_name, to_name)].edge_index = result

return data
return data, col_to_adb_id_to_ind
22 changes: 14 additions & 8 deletions python/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ def test_phenol_abide_hetero(
username=connection_information["username"],
password=connection_information["password"],
)
assert isinstance(result, HeteroData)
assert result["Subjects"]["x"].shape == (871, 2000)
data, col_to_adb_id_to_ind = result
assert isinstance(data, HeteroData)
assert data["Subjects"]["x"].shape == (871, 2000)
assert len(col_to_adb_id_to_ind["Subjects"]) == 871

# Metagraph variation
result = PygLoader.load_into_pyg_heterodata(
connection_information["dbName"],
{
Expand All @@ -34,8 +37,11 @@ def test_phenol_abide_hetero(
username=connection_information["username"],
password=connection_information["password"],
)
assert isinstance(result, HeteroData)
assert result["Subjects"]["x"].shape == (871, 2000)

data, col_to_adb_id_to_ind = result
assert isinstance(data, HeteroData)
assert data["Subjects"]["x"].shape == (871, 2000)
assert len(col_to_adb_id_to_ind["Subjects"]) == 871


def test_phenol_abide_numpy(
Expand All @@ -44,7 +50,7 @@ def test_phenol_abide_numpy(
(
features_by_col,
coo_map,
col_to_key_inds,
col_to_adb_id_to_ind,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
connection_information["dbName"],
Expand All @@ -62,13 +68,13 @@ def test_phenol_abide_numpy(
2,
606770,
)
assert len(col_to_key_inds["Subjects"]) == 871
assert len(col_to_adb_id_to_ind["Subjects"]) == 871
assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}}

(
features_by_col,
coo_map,
col_to_key_inds,
col_to_adb_id_to_ind,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
connection_information["dbName"],
Expand All @@ -83,5 +89,5 @@ def test_phenol_abide_numpy(

assert features_by_col["Subjects"]["brain_fmri_features"].shape == (871, 2000)
assert len(coo_map) == 0
assert len(col_to_key_inds["Subjects"]) == 871
assert len(col_to_adb_id_to_ind["Subjects"]) == 871
assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}}

0 comments on commit 6c5f11c

Please sign in to comment.