From 7e46c5f9b25c1c0c4866b556f5e16383b8eb08e1 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Sun, 28 Apr 2024 22:17:30 -0400 Subject: [PATCH 01/27] wip: nxadb-to-nxcg using the adapter for now... --- .../algorithms/centrality/betweenness.py | 13 ++- nx_arangodb/classes/graph.py | 82 +++++++++++++++++++ nx_arangodb/convert.py | 69 +++++++++++++++- 3 files changed, 154 insertions(+), 10 deletions(-) diff --git a/nx_arangodb/algorithms/centrality/betweenness.py b/nx_arangodb/algorithms/centrality/betweenness.py index c8bd0f1b..8b8b945b 100644 --- a/nx_arangodb/algorithms/centrality/betweenness.py +++ b/nx_arangodb/algorithms/centrality/betweenness.py @@ -1,11 +1,10 @@ from networkx.algorithms.centrality import betweenness as nx_betweenness -from nx_arangodb.convert import _to_graph as _to_nx_arangodb_graph +from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph from nx_arangodb.utils import networkx_algorithm try: import pylibcugraph as plc - from nx_cugraph.convert import _to_graph as _to_nx_cugraph_graph from nx_cugraph.utils import _seed_to_int GPU_ENABLED = True @@ -15,9 +14,9 @@ __all__ = ["betweenness_centrality"] -# 1. If GPU is enabled, call nx-cugraph bc() after converting to a nx_cugraph graph (in-memory graph) -# 2. If GPU is not enabled, call networkx bc() after converting to a networkx graph (in-memory graph) -# 3. If GPU is not enabled, call networkx bc() **without** converting to a networkx graph (remote graph) +# 1. If GPU is enabled, call nx-cugraph bc() after converting to an ncxg graph (in-memory graph) +# 2. If GPU is not enabled, call networkx bc() after converting to an nxadb graph (in-memory graph) +# 3. If GPU is not enabled, call networkx bc() **without** converting to a nxadb graph (remote graph) @networkx_algorithm( @@ -41,7 +40,7 @@ def betweenness_centrality( ) seed = _seed_to_int(seed) - G = _to_nx_cugraph_graph(G, weight) + G = _to_nxcg_graph(G, weight) node_ids, values = plc.betweenness_centrality( resource_handle=plc.ResourceHandle(), graph=G._get_plc_graph(), @@ -58,7 +57,7 @@ def betweenness_centrality( else: print("ANTHONY: GPU is disabled. Using nx bc()") - G = _to_nx_arangodb_graph(G) + G = _to_nxadb_graph(G) betweenness = dict.fromkeys(G, 0.0) # b[v]=0 for v in G if k is None: diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index db495e36..5b5ebfa6 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -1,6 +1,9 @@ +import os from typing import ClassVar import networkx as nx +from arango import ArangoClient +from arango.database import StandardDatabase import nx_arangodb as nxadb @@ -16,3 +19,82 @@ class Graph(nx.Graph): @classmethod def to_networkx_class(cls) -> type[nx.Graph]: return nx.Graph + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.set_db() + + self.__graph_exists = False + if self.__db is not None: + self.set_graph_name() + + @property + def db(self) -> StandardDatabase: + if self.__db is None: + raise ValueError("Database not set") + + return self.__db + + @property + def graph_name(self) -> str: + if self.__graph_name is None: + raise ValueError("Graph name not set") + + return self.__graph_name + + @property + def graph_exists(self) -> bool: + return self.__graph_exists + + def set_db(self, db: StandardDatabase | None = None): + if db is not None: + if not isinstance(db, StandardDatabase): + raise TypeError( + "**db** must be an instance of arango.database.StandardDatabase" + ) + + self.__db = db + return + + host = os.getenv("DATABASE_HOST") + username = os.getenv("DATABASE_USERNAME") + password = os.getenv("DATABASE_PASSWORD") + db_name = os.getenv("DATABASE_NAME") + + # TODO: Raise a custom exception if any of the environment + # variables are missing. For now, we'll just set db to None. + if not all([host, username, password, db_name]): + self.__db = None + return + + self.__db = ArangoClient(host=host, request_timeout=None).db( + db_name, username, password, verify=True + ) + + def set_graph_name(self, graph_name: str | None = None): + if self.__db is None: + raise ValueError("Cannot set graph name without setting the database first") + + self.__graph_name = os.getenv("DATABASE_GRAPH_NAME") + if graph_name is not None: + if not isinstance(graph_name, str): + raise TypeError("**graph_name** must be a string") + + self.__graph_name = graph_name + + if self.__graph_name is None: + self.graph_exists = False + print("DATABASE_GRAPH_NAME environment variable not set") + + elif not self.db.has_graph(self.__graph_name): + self.graph_exists = False + print(f"Graph '{self.__graph_name}' does not exist in the database") + + else: + self.graph_exists = True + print(f"Found graph '{self.__graph_name}' in the database") diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index fae54ce4..9b9664f0 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -167,7 +167,7 @@ def to_networkx(G: nxadb.Graph, *, sort_edges: bool = False) -> nx.Graph: return G.to_networkx_class()(incoming_graph_data=G) -def _to_graph( +def _to_nxadb_graph( G, edge_attr: AttrKey | None = None, edge_default: EdgeValue | None = 1, @@ -188,7 +188,7 @@ def _to_graph( raise TypeError -def _to_directed_graph( +def _to_nxadb_directed_graph( G, edge_attr: AttrKey | None = None, edge_default: EdgeValue | None = 1, @@ -214,7 +214,7 @@ def _to_directed_graph( raise TypeError -def _to_undirected_graph( +def _to_nxadb_undirected_graph( G, edge_attr: AttrKey | None = None, edge_default: EdgeValue | None = 1, @@ -235,3 +235,66 @@ def _to_undirected_graph( ) # TODO: handle cugraph.Graph raise TypeError + + +try: + import nx_cugraph as nxcg + from adbnx_adapter import ADBNX_Adapter + + def _to_nxcg_graph( + G, + edge_attr: AttrKey | None = None, + edge_default: EdgeValue | None = 1, + edge_dtype: Dtype | None = None, + ) -> nxcg.Graph | nxcg.DiGraph: + """Ensure that input type is a nx_cugraph graph, and convert if necessary. + + Directed and undirected graphs are both allowed. + This is an internal utility function and may change or be removed. + """ + if isinstance(G, nxcg.Graph): + return G + if isinstance(G, nxadb.Graph): + # Assumption: G.adb_graph_name points to an existing graph in ArangoDB + # Therefore, the user wants us to pull the graph from ArangoDB, + # and convert it to an nx_cugraph graph. + # We currently accomplish this by using the NetworkX adapter for ArangoDB, + # which converts the ArangoDB graph to a NetworkX graph, and then we convert + # the NetworkX graph to an nx_cugraph graph. + # TODO: Implement a direct conversion from ArangoDB to nx_cugraph + if G.graph_exists: + adapter = ADBNX_Adapter(G.db) + nx_g = adapter.arangodb_graph_to_networkx( + G.graph_name, G.to_networkx_class()() + ) + + return nxcg.convert.from_networkx( + nx_g, + {edge_attr: edge_default} if edge_attr is not None else None, + edge_dtype, + ) + + # If G is a networkx graph, or is a nxadb graph that doesn't point to an "existing" + # ArangoDB graph, then we just treat it as a normal networkx graph & + # convert it to nx_cugraph. + # TODO: Need to revisit the "existing" ArangoDB graph condition... + if isinstance(G, nx.Graph): + return nxcg.convert.from_networkx( + G, + {edge_attr: edge_default} if edge_attr is not None else None, + edge_dtype, + ) + + # TODO: handle cugraph.Graph + raise TypeError + +except ModuleNotFoundError: + + def _to_nxcg_graph( + G, + edge_attr: AttrKey | None = None, + edge_default: EdgeValue | None = 1, + edge_dtype: Dtype | None = None, + ) -> nxadb.Graph: + m = "nx-cugraph is not installed; cannot convert to nx-cugraph graph" + raise NotImplementedError(m) From 5cd2afd247c0ae14295410c184ba18f46b743ba7 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Mon, 29 Apr 2024 11:28:32 -0400 Subject: [PATCH 02/27] fix: typo --- nx_arangodb/classes/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 5b5ebfa6..5dff376c 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -72,7 +72,7 @@ def set_db(self, db: StandardDatabase | None = None): self.__db = None return - self.__db = ArangoClient(host=host, request_timeout=None).db( + self.__db = ArangoClient(hosts=host, request_timeout=None).db( db_name, username, password, verify=True ) From 0b4b1cfb3fde54ad1f81b27d2024c99dd0408ad5 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Mon, 29 Apr 2024 11:30:47 -0400 Subject: [PATCH 03/27] attempt fix: graph classes --- nx_arangodb/classes/digraph.py | 13 +++++++++++++ nx_arangodb/classes/multidigraph.py | 13 +++++++++++++ nx_arangodb/classes/multigraph.py | 13 +++++++++++++ 3 files changed, 39 insertions(+) diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index a3add415..cd7dce39 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -12,3 +12,16 @@ class DiGraph(nx.DiGraph, Graph): @classmethod def to_networkx_class(cls) -> type[nx.DiGraph]: return nx.DiGraph + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.set_db() + + self.__graph_exists = False + if self.__db is not None: + self.set_graph_name() diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py index 3b83aea5..8337a46d 100644 --- a/nx_arangodb/classes/multidigraph.py +++ b/nx_arangodb/classes/multidigraph.py @@ -13,3 +13,16 @@ class MultiDiGraph(nx.MultiDiGraph, MultiGraph, DiGraph): @classmethod def to_networkx_class(cls) -> type[nx.MultiDiGraph]: return nx.MultiDiGraph + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.set_db() + + self.__graph_exists = False + if self.__db is not None: + self.set_graph_name() diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index 6afe0c63..b93aa2f4 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -12,3 +12,16 @@ class MultiGraph(nx.MultiGraph, Graph): @classmethod def to_networkx_class(cls) -> type[nx.MultiGraph]: return nx.MultiGraph + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.set_db() + + self.__graph_exists = False + if self.__db is not None: + self.set_graph_name() From e254390b0e56cc535cbab7ed7efe0d40dd851ab7 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Mon, 29 Apr 2024 11:36:42 -0400 Subject: [PATCH 04/27] fix: graph classes (again) --- nx_arangodb/classes/digraph.py | 6 ++++-- nx_arangodb/classes/multidigraph.py | 6 ++++-- nx_arangodb/classes/multigraph.py | 6 ++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index cd7dce39..2f37e567 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -20,8 +20,10 @@ def __init__( ): super().__init__(*args, **kwargs) - self.set_db() - + self.__db = None + self.__graph_name = None self.__graph_exists = False + + self.set_db() if self.__db is not None: self.set_graph_name() diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py index 8337a46d..a2933da4 100644 --- a/nx_arangodb/classes/multidigraph.py +++ b/nx_arangodb/classes/multidigraph.py @@ -21,8 +21,10 @@ def __init__( ): super().__init__(*args, **kwargs) - self.set_db() - + self.__db = None + self.__graph_name = None self.__graph_exists = False + + self.set_db() if self.__db is not None: self.set_graph_name() diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index b93aa2f4..83f8e5e3 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -20,8 +20,10 @@ def __init__( ): super().__init__(*args, **kwargs) - self.set_db() - + self.__db = None + self.__graph_name = None self.__graph_exists = False + + self.set_db() if self.__db is not None: self.set_graph_name() From c76124a7cb04930e0792c0b499954cc0aef1a585 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Mon, 29 Apr 2024 11:47:05 -0400 Subject: [PATCH 05/27] fix: typo --- nx_arangodb/classes/graph.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 5dff376c..d2deace8 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -27,9 +27,11 @@ def __init__( ): super().__init__(*args, **kwargs) - self.set_db() - + self.__db = None + self.__graph_name = None self.__graph_exists = False + + self.set_db() if self.__db is not None: self.set_graph_name() @@ -88,13 +90,13 @@ def set_graph_name(self, graph_name: str | None = None): self.__graph_name = graph_name if self.__graph_name is None: - self.graph_exists = False + self.__graph_exists = False print("DATABASE_GRAPH_NAME environment variable not set") elif not self.db.has_graph(self.__graph_name): - self.graph_exists = False + self.__graph_exists = False print(f"Graph '{self.__graph_name}' does not exist in the database") else: - self.graph_exists = True + self.__graph_exists = True print(f"Found graph '{self.__graph_name}' in the database") From eed930add21213709d16b7db2944cbf92e9c4063 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Mon, 29 Apr 2024 12:02:00 -0400 Subject: [PATCH 06/27] add DiGraph property not sure what's going on.. --- nx_arangodb/classes/digraph.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 2f37e567..36401d12 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -27,3 +27,7 @@ def __init__( self.set_db() if self.__db is not None: self.set_graph_name() + + @property + def graph_exists(self) -> bool: + return self.__graph_exists \ No newline at end of file From adefe323b9f57e2e99232233f8a70ee50a017d1c Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Tue, 30 Apr 2024 22:15:29 -0400 Subject: [PATCH 07/27] nxadb-to-nxcg (rust) | initial commit --- nx_arangodb/convert.py | 75 ++++++++++++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 18 deletions(-) diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 9b9664f0..37563fdf 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -3,21 +3,14 @@ from __future__ import annotations import itertools -import operator as op -from collections import Counter -from collections.abc import Mapping from typing import TYPE_CHECKING -# import cupy as cp import networkx as nx -import numpy as np import nx_arangodb as nxadb -from .utils import index_dtype - if TYPE_CHECKING: # pragma: no cover - from nx_arangodb.typing import AttrKey, Dtype, EdgeValue, NodeValue, any_ndarray + from nx_arangodb.typing import AttrKey, Dtype, EdgeValue, NodeValue __all__ = [ "from_networkx", @@ -238,8 +231,12 @@ def _to_nxadb_undirected_graph( try: + import os + + import cupy as cp + import numpy as np import nx_cugraph as nxcg - from adbnx_adapter import ADBNX_Adapter + from phenolrs.coo_loader import CooLoader def _to_nxcg_graph( G, @@ -263,13 +260,8 @@ def _to_nxcg_graph( # the NetworkX graph to an nx_cugraph graph. # TODO: Implement a direct conversion from ArangoDB to nx_cugraph if G.graph_exists: - adapter = ADBNX_Adapter(G.db) - nx_g = adapter.arangodb_graph_to_networkx( - G.graph_name, G.to_networkx_class()() - ) - - return nxcg.convert.from_networkx( - nx_g, + return _from_networkx_arangodb( + G, {edge_attr: edge_default} if edge_attr is not None else None, edge_dtype, ) @@ -288,9 +280,56 @@ def _to_nxcg_graph( # TODO: handle cugraph.Graph raise TypeError -except ModuleNotFoundError: + def _from_networkx_arangodb( + G: nxadb.Graph, as_directed: bool = False + ) -> nxcg.Graph | nxcg.DiGraph: + if G.is_multigraph(): + raise NotImplementedError("Multigraphs not yet supported") + + adb_graph = G.db.graph(G.graph_name) + + v_cols = adb_graph.vertex_collections() + edge_definitions = adb_graph.edge_definitions() + e_cols = {c["edge_collection"] for c in edge_definitions} + + metagraph = { + "vertexCollections": {col: {} for col in v_cols}, + "edgeCollections": {col: {} for col in e_cols}, + } + + src_indices, dst_indices, vertex_ids = CooLoader.load_coo( + G.db.name, + metagraph, + [os.environ["DATABASE_HOST"]], + username=os.environ["DATABASE_USERNAME"], + password=os.environ["DATABASE_PASSWORD"], + # parallelism=, + # batch_size= + ) - def _to_nxcg_graph( + src_indices = cp.array(src_indices) + dst_indices = cp.array(dst_indices) + + N = len(vertex_ids) + + if G.is_directed() or as_directed: + klass = nxcg.DiGraph + else: + klass = nxcg.Graph + + rv = klass.from_coo( + N, + src_indices, + dst_indices, + key_to_id={k: i for i, k in enumerate(vertex_ids)}, + ) + + return rv + +except ModuleNotFoundError as e: + print(f"ANTHONY: {e}") + + def _from_networkx_arangodb( G, edge_attr: AttrKey | None = None, edge_default: EdgeValue | None = 1, From b52b63c269c6488bca9666c0c94186b8cc0edaf0 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Tue, 30 Apr 2024 22:28:16 -0400 Subject: [PATCH 08/27] print statements --- nx_arangodb/convert.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 37563fdf..4ce77137 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -232,6 +232,7 @@ def _to_nxadb_undirected_graph( try: import os + import time import cupy as cp import numpy as np @@ -260,6 +261,7 @@ def _to_nxcg_graph( # the NetworkX graph to an nx_cugraph graph. # TODO: Implement a direct conversion from ArangoDB to nx_cugraph if G.graph_exists: + print("ANTHONY: Graph exists! Running _from_networkx_arangodb()") return _from_networkx_arangodb( G, {edge_attr: edge_default} if edge_attr is not None else None, @@ -286,6 +288,7 @@ def _from_networkx_arangodb( if G.is_multigraph(): raise NotImplementedError("Multigraphs not yet supported") + print("ANTHONY: Building metagraph...") adb_graph = G.db.graph(G.graph_name) v_cols = adb_graph.vertex_collections() @@ -297,6 +300,10 @@ def _from_networkx_arangodb( "edgeCollections": {col: {} for col in e_cols}, } + print("ANTHONY: Running COO Loader...") + + start_time = time.time() + src_indices, dst_indices, vertex_ids = CooLoader.load_coo( G.db.name, metagraph, @@ -307,6 +314,12 @@ def _from_networkx_arangodb( # batch_size= ) + end_time = time.time() + + print("ANTHONY: COO Load took:", end_time - start_time) + + print("ANTHONY: Converting to cupy arrays...") + src_indices = cp.array(src_indices) dst_indices = cp.array(dst_indices) @@ -317,11 +330,15 @@ def _from_networkx_arangodb( else: klass = nxcg.Graph + print("ANTHONY: Running nx_cugraph.from_coo()...") + + key_to_id = {k: i for i, k in enumerate(vertex_ids)} + rv = klass.from_coo( N, src_indices, dst_indices, - key_to_id={k: i for i, k in enumerate(vertex_ids)}, + key_to_id=key_to_id, ) return rv From fb6e10a81565a44f7b8c191f10aa368bca36921a Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Tue, 30 Apr 2024 22:38:02 -0400 Subject: [PATCH 09/27] fix: function name --- nx_arangodb/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 4ce77137..f6e56380 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -346,7 +346,7 @@ def _from_networkx_arangodb( except ModuleNotFoundError as e: print(f"ANTHONY: {e}") - def _from_networkx_arangodb( + def _to_nxcg_graph( G, edge_attr: AttrKey | None = None, edge_default: EdgeValue | None = 1, From c00c08b7894bbc770ec34c33fe1ccb7fd68a39fe Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Tue, 30 Apr 2024 23:22:23 -0400 Subject: [PATCH 10/27] fix: `as_directed` --- nx_arangodb/convert.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index f6e56380..d0f403ad 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -244,6 +244,7 @@ def _to_nxcg_graph( edge_attr: AttrKey | None = None, edge_default: EdgeValue | None = 1, edge_dtype: Dtype | None = None, + as_directed: bool = False, ) -> nxcg.Graph | nxcg.DiGraph: """Ensure that input type is a nx_cugraph graph, and convert if necessary. @@ -262,11 +263,7 @@ def _to_nxcg_graph( # TODO: Implement a direct conversion from ArangoDB to nx_cugraph if G.graph_exists: print("ANTHONY: Graph exists! Running _from_networkx_arangodb()") - return _from_networkx_arangodb( - G, - {edge_attr: edge_default} if edge_attr is not None else None, - edge_dtype, - ) + return _from_networkx_arangodb(G, as_directed=as_directed) # If G is a networkx graph, or is a nxadb graph that doesn't point to an "existing" # ArangoDB graph, then we just treat it as a normal networkx graph & @@ -277,6 +274,7 @@ def _to_nxcg_graph( G, {edge_attr: edge_default} if edge_attr is not None else None, edge_dtype, + as_directed=as_directed, ) # TODO: handle cugraph.Graph @@ -351,6 +349,7 @@ def _to_nxcg_graph( edge_attr: AttrKey | None = None, edge_default: EdgeValue | None = 1, edge_dtype: Dtype | None = None, + as_directed: bool = False, ) -> nxadb.Graph: m = "nx-cugraph is not installed; cannot convert to nx-cugraph graph" raise NotImplementedError(m) From e9c67edb593669a38edcc80b7c55c91f13442aed Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 10:20:38 -0400 Subject: [PATCH 11/27] more print statements --- nx_arangodb/convert.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index d0f403ad..f75edfc8 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -286,7 +286,6 @@ def _from_networkx_arangodb( if G.is_multigraph(): raise NotImplementedError("Multigraphs not yet supported") - print("ANTHONY: Building metagraph...") adb_graph = G.db.graph(G.graph_name) v_cols = adb_graph.vertex_collections() @@ -298,8 +297,6 @@ def _from_networkx_arangodb( "edgeCollections": {col: {} for col in e_cols}, } - print("ANTHONY: Running COO Loader...") - start_time = time.time() src_indices, dst_indices, vertex_ids = CooLoader.load_coo( @@ -316,11 +313,15 @@ def _from_networkx_arangodb( print("ANTHONY: COO Load took:", end_time - start_time) - print("ANTHONY: Converting to cupy arrays...") + start_time = time.time() src_indices = cp.array(src_indices) dst_indices = cp.array(dst_indices) + end_time = time.time() + + print("ANTHONY: cupy arrays took:", end_time - start_time) + N = len(vertex_ids) if G.is_directed() or as_directed: @@ -328,16 +329,25 @@ def _from_networkx_arangodb( else: klass = nxcg.Graph - print("ANTHONY: Running nx_cugraph.from_coo()...") + start_time = time.time() key_to_id = {k: i for i, k in enumerate(vertex_ids)} + end_time = time.time() + + print("ANTHONY: key_to_id took:", end_time - start_time) + + start_time = time.time() + rv = klass.from_coo( N, src_indices, dst_indices, key_to_id=key_to_id, ) + end_time = time.time() + + print("ANTHONY: from_coo took:", end_time - start_time) return rv From 8a4b845631ddccd2c3185d075220c3e707501137 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 11:20:34 -0400 Subject: [PATCH 12/27] cleanup: `vertex_ids_to_index` --- nx_arangodb/convert.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index f75edfc8..0389294d 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -299,7 +299,7 @@ def _from_networkx_arangodb( start_time = time.time() - src_indices, dst_indices, vertex_ids = CooLoader.load_coo( + src_indices, dst_indices, vertex_ids_to_index = CooLoader.load_coo( G.db.name, metagraph, [os.environ["DATABASE_HOST"]], @@ -322,7 +322,7 @@ def _from_networkx_arangodb( print("ANTHONY: cupy arrays took:", end_time - start_time) - N = len(vertex_ids) + N = len(vertex_ids_to_index) if G.is_directed() or as_directed: klass = nxcg.DiGraph @@ -331,19 +331,11 @@ def _from_networkx_arangodb( start_time = time.time() - key_to_id = {k: i for i, k in enumerate(vertex_ids)} - - end_time = time.time() - - print("ANTHONY: key_to_id took:", end_time - start_time) - - start_time = time.time() - rv = klass.from_coo( N, src_indices, dst_indices, - key_to_id=key_to_id, + key_to_id=vertex_ids_to_index, ) end_time = time.time() From 1d4b8532795584f7ce3aa934aa1e2663a2cee955 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 11:28:57 -0400 Subject: [PATCH 13/27] new: `parallelism` & `batch_size` kwargs hacky for now... --- nx_arangodb/classes/graph.py | 3 +++ nx_arangodb/classes/multidigraph.py | 3 +++ nx_arangodb/classes/multigraph.py | 3 +++ nx_arangodb/convert.py | 9 +++++++-- 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index d2deace8..33862baf 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -31,6 +31,9 @@ def __init__( self.__graph_name = None self.__graph_exists = False + self.coo_load_parallelism = None + self.coo_load_batch_size = None + self.set_db() if self.__db is not None: self.set_graph_name() diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py index a2933da4..e2f86584 100644 --- a/nx_arangodb/classes/multidigraph.py +++ b/nx_arangodb/classes/multidigraph.py @@ -25,6 +25,9 @@ def __init__( self.__graph_name = None self.__graph_exists = False + self.coo_load_parallelism = None + self.coo_load_batch_size = None + self.set_db() if self.__db is not None: self.set_graph_name() diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index 83f8e5e3..0a252284 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -24,6 +24,9 @@ def __init__( self.__graph_name = None self.__graph_exists = False + self.coo_load_parallelism = None + self.coo_load_batch_size = None + self.set_db() if self.__db is not None: self.set_graph_name() diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 0389294d..870d0665 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -299,14 +299,19 @@ def _from_networkx_arangodb( start_time = time.time() + kwargs = {} + if G.coo_load_parallelism is not None: + kwargs["parallelism"] = G.coo_load_parallelism + if G.coo_load_batch_size is not None: + kwargs["batch_size"] = G.coo_load_batch_size + src_indices, dst_indices, vertex_ids_to_index = CooLoader.load_coo( G.db.name, metagraph, [os.environ["DATABASE_HOST"]], username=os.environ["DATABASE_USERNAME"], password=os.environ["DATABASE_PASSWORD"], - # parallelism=, - # batch_size= + **kwargs, ) end_time = time.time() From 124f049626027136bceaa19df58042aa5961fbf0 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 11:29:04 -0400 Subject: [PATCH 14/27] Update digraph.py --- nx_arangodb/classes/digraph.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 36401d12..4ad86480 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -24,10 +24,13 @@ def __init__( self.__graph_name = None self.__graph_exists = False + self.coo_load_parallelism = None + self.coo_load_batch_size = None + self.set_db() if self.__db is not None: self.set_graph_name() @property def graph_exists(self) -> bool: - return self.__graph_exists \ No newline at end of file + return self.__graph_exists From f196e705979c99758013085906d52141758cdfa2 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 14:09:15 -0400 Subject: [PATCH 15/27] new: cache coo --- nx_arangodb/classes/digraph.py | 4 ++ nx_arangodb/classes/graph.py | 9 ++++ nx_arangodb/convert.py | 77 ++++++++++++++++++++-------------- 3 files changed, 58 insertions(+), 32 deletions(-) diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 4ad86480..9c8cd1e5 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -24,8 +24,12 @@ def __init__( self.__graph_name = None self.__graph_exists = False + self.coo_use_cache = False self.coo_load_parallelism = None self.coo_load_batch_size = None + self.src_indices = None + self.dst_indices = None + self.vertex_ids_to_index = None self.set_db() if self.__db is not None: diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 33862baf..4a8f916e 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -31,8 +31,12 @@ def __init__( self.__graph_name = None self.__graph_exists = False + self.coo_use_cache = False self.coo_load_parallelism = None self.coo_load_batch_size = None + self.src_indices = None + self.dst_indices = None + self.vertex_ids_to_index = None self.set_db() if self.__db is not None: @@ -56,6 +60,11 @@ def graph_name(self) -> str: def graph_exists(self) -> bool: return self.__graph_exists + def clear_coo_cache(self): + self.src_indices = None + self.dst_indices = None + self.vertex_ids_to_index = None + def set_db(self, db: StandardDatabase | None = None): if db is not None: if not isinstance(db, StandardDatabase): diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 870d0665..50c37581 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -286,46 +286,59 @@ def _from_networkx_arangodb( if G.is_multigraph(): raise NotImplementedError("Multigraphs not yet supported") - adb_graph = G.db.graph(G.graph_name) + if G.coo_use_cache and all( + [G.src_indices, G.dst_indices, G.vertex_ids_to_index] + ): + src_indices = G.src_indices + dst_indices = G.dst_indices + vertex_ids_to_index = G.vertex_ids_to_index - v_cols = adb_graph.vertex_collections() - edge_definitions = adb_graph.edge_definitions() - e_cols = {c["edge_collection"] for c in edge_definitions} + else: + adb_graph = G.db.graph(G.graph_name) + + v_cols = adb_graph.vertex_collections() + edge_definitions = adb_graph.edge_definitions() + e_cols = {c["edge_collection"] for c in edge_definitions} + + metagraph = { + "vertexCollections": {col: {} for col in v_cols}, + "edgeCollections": {col: {} for col in e_cols}, + } + + start_time = time.time() + + kwargs = {} + if G.coo_load_parallelism is not None: + kwargs["parallelism"] = G.coo_load_parallelism + if G.coo_load_batch_size is not None: + kwargs["batch_size"] = G.coo_load_batch_size + + src_indices, dst_indices, vertex_ids_to_index = CooLoader.load_coo( + G.db.name, + metagraph, + [os.environ["DATABASE_HOST"]], + username=os.environ["DATABASE_USERNAME"], + password=os.environ["DATABASE_PASSWORD"], + **kwargs, + ) - metagraph = { - "vertexCollections": {col: {} for col in v_cols}, - "edgeCollections": {col: {} for col in e_cols}, - } + end_time = time.time() - start_time = time.time() + print("ANTHONY: COO Load took:", end_time - start_time) - kwargs = {} - if G.coo_load_parallelism is not None: - kwargs["parallelism"] = G.coo_load_parallelism - if G.coo_load_batch_size is not None: - kwargs["batch_size"] = G.coo_load_batch_size - - src_indices, dst_indices, vertex_ids_to_index = CooLoader.load_coo( - G.db.name, - metagraph, - [os.environ["DATABASE_HOST"]], - username=os.environ["DATABASE_USERNAME"], - password=os.environ["DATABASE_PASSWORD"], - **kwargs, - ) + start_time = time.time() - end_time = time.time() + src_indices = cp.array(src_indices) + dst_indices = cp.array(dst_indices) - print("ANTHONY: COO Load took:", end_time - start_time) + end_time = time.time() - start_time = time.time() - - src_indices = cp.array(src_indices) - dst_indices = cp.array(dst_indices) - - end_time = time.time() + print("ANTHONY: cupy arrays took:", end_time - start_time) - print("ANTHONY: cupy arrays took:", end_time - start_time) + if G.coo_use_cache: + G.src_indices = src_indices + G.dst_indices = dst_indices + G.vertex_ids_to_index = vertex_ids_to_index N = len(vertex_ids_to_index) From 2b16ed1f8665cc2177a0ebd64fbdf5799ea7cdab Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 14:19:39 -0400 Subject: [PATCH 16/27] cleanup --- nx_arangodb/__init__.py | 4 --- nx_arangodb/algorithms/__init__.py | 4 ++- .../algorithms/centrality/betweenness.py | 32 ++++++------------- 3 files changed, 12 insertions(+), 28 deletions(-) diff --git a/nx_arangodb/__init__.py b/nx_arangodb/__init__.py index 4400a9f8..fcb0b7ac 100644 --- a/nx_arangodb/__init__.py +++ b/nx_arangodb/__init__.py @@ -10,10 +10,6 @@ from . import convert from .convert import * -# TODO Anthony: Do we need this? -# from . import convert_matrix -# from .convert_matrix import * - from . import algorithms from .algorithms import * diff --git a/nx_arangodb/algorithms/__init__.py b/nx_arangodb/algorithms/__init__.py index 570d5dd6..60bb5633 100644 --- a/nx_arangodb/algorithms/__init__.py +++ b/nx_arangodb/algorithms/__init__.py @@ -1,2 +1,4 @@ -from . import centrality +from . import centrality, community, link_analysis from .centrality import * +from .community import * +from .link_analysis import * diff --git a/nx_arangodb/algorithms/centrality/betweenness.py b/nx_arangodb/algorithms/centrality/betweenness.py index 8b8b945b..2ec752b4 100644 --- a/nx_arangodb/algorithms/centrality/betweenness.py +++ b/nx_arangodb/algorithms/centrality/betweenness.py @@ -4,12 +4,13 @@ from nx_arangodb.utils import networkx_algorithm try: - import pylibcugraph as plc - from nx_cugraph.utils import _seed_to_int + import nx_cugraph as nxcg GPU_ENABLED = True + print("ANTHONY: GPU is enabled") except ModuleNotFoundError: GPU_ENABLED = False + print("ANTHONY: GPU is disabled") __all__ = ["betweenness_centrality"] @@ -32,33 +33,20 @@ def betweenness_centrality( # 1. if GPU_ENABLED and run_on_gpu: - print("ANTHONY: GPU is enabled. Using nx-cugraph bc()") - - if weight is not None: - raise NotImplementedError( - "Weighted implementation of betweenness centrality not currently supported" - ) - - seed = _seed_to_int(seed) + print("ANTHONY: to_nxcg") G = _to_nxcg_graph(G, weight) - node_ids, values = plc.betweenness_centrality( - resource_handle=plc.ResourceHandle(), - graph=G._get_plc_graph(), - k=k, - random_state=seed, - normalized=normalized, - include_endpoints=endpoints, - do_expensive_check=False, - ) - return G._nodearrays_to_dict(node_ids, values) + print("ANTHONY: Using nxcg bc()") + return nxcg.betweenness_centrality(G, k=k, normalized=normalized, weight=weight) # 2. else: - print("ANTHONY: GPU is disabled. Using nx bc()") + print("ANTHONY: to_nxadb") G = _to_nxadb_graph(G) + print("ANTHONY: Using nx bc()") + betweenness = dict.fromkeys(G, 0.0) # b[v]=0 for v in G if k is None: nodes = G @@ -92,5 +80,3 @@ def betweenness_centrality( ) return betweenness - - # 3. TODO From c820d17122920514b02aa77b64c90ca42e1eba4e Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 14:19:50 -0400 Subject: [PATCH 17/27] new: `louvain` & `pagerank` --- nx_arangodb/algorithms/community/__init__.py | 0 nx_arangodb/algorithms/community/louvain.py | 44 ++++++++++++ .../algorithms/link_analysis/__init__.py | 0 .../algorithms/link_analysis/pagerank_alg.py | 70 +++++++++++++++++++ 4 files changed, 114 insertions(+) create mode 100644 nx_arangodb/algorithms/community/__init__.py create mode 100644 nx_arangodb/algorithms/community/louvain.py create mode 100644 nx_arangodb/algorithms/link_analysis/__init__.py create mode 100644 nx_arangodb/algorithms/link_analysis/pagerank_alg.py diff --git a/nx_arangodb/algorithms/community/__init__.py b/nx_arangodb/algorithms/community/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nx_arangodb/algorithms/community/louvain.py b/nx_arangodb/algorithms/community/louvain.py new file mode 100644 index 00000000..a593b848 --- /dev/null +++ b/nx_arangodb/algorithms/community/louvain.py @@ -0,0 +1,44 @@ +from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph +from nx_arangodb.utils import _dtype_param, networkx_algorithm + +try: + import nx_cugraph as nxcg + + GPU_ENABLED = True + print("ANTHONY: GPU is enabled") +except ModuleNotFoundError: + GPU_ENABLED = False + print("ANTHONY: GPU is disabled") + + +@networkx_algorithm( + extra_params={ + **_dtype_param, + }, + is_incomplete=True, # seed not supported; self-loops not supported + is_different=True, # RNG different + version_added="23.10", + _plc="louvain", + name="louvain_communities", +) +def louvain_communities( + G, weight="weight", resolution=1, threshold=0.0000001, max_level=None, seed=None +): + if GPU_ENABLED: + print("ANTHONY: to_nxcg") + G = _to_nxcg_graph(G, weight) + + print("ANTHONY: Using nxcg louvain()") + return nxcg._louvain_communities( + G, + weight=weight, + resolution=resolution, + threshold=threshold, + max_level=max_level, + seed=seed, + ) + + else: + raise NotImplementedError( + "Louvain community detection is not supported on CPU for nxadb" + ) diff --git a/nx_arangodb/algorithms/link_analysis/__init__.py b/nx_arangodb/algorithms/link_analysis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nx_arangodb/algorithms/link_analysis/pagerank_alg.py b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py new file mode 100644 index 00000000..6ca2ffbf --- /dev/null +++ b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py @@ -0,0 +1,70 @@ +from networkx.algorithms.link_analysis.pagerank_alg import _pagerank_scipy + +from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph +from nx_arangodb.utils import _dtype_param, networkx_algorithm + +try: + import nx_cugraph as nxcg + + GPU_ENABLED = True + print("ANTHONY: GPU is enabled") +except ModuleNotFoundError: + GPU_ENABLED = False + print("ANTHONY: GPU is disabled") + + +@networkx_algorithm( + extra_params=_dtype_param, + is_incomplete=True, # dangling not supported + version_added="23.12", + _plc={"pagerank", "personalized_pagerank"}, +) +def pagerank( + G, + alpha=0.85, + personalization=None, + max_iter=100, + tol=1.0e-6, + nstart=None, + weight="weight", + dangling=None, + *, + dtype=None, + run_on_gpu=True, +): + print("ANTHONY: Calling pagerank from nx_arangodb") + + # 1. + if GPU_ENABLED and run_on_gpu: + print("ANTHONY: to_nxcg") + G = _to_nxcg_graph(G, weight) + + print("ANTHONY: Using nxcg pagerank()") + return nxcg.pagerank( + G, + alpha=alpha, + personalization=personalization, + max_iter=max_iter, + tol=tol, + nstart=nstart, + weight=weight, + dangling=dangling, + dtype=dtype, + ) + + # 2. + else: + print("ANTHONY: to_nxadb") + G = _to_nxadb_graph(G) + + print("ANTHONY: Using nx pagerank()") + return _pagerank_scipy( + G, + alpha=alpha, + personalization=personalization, + max_iter=max_iter, + tol=tol, + nstart=nstart, + weight=weight, + dangling=dangling, + ) From 5e240f2791051efd85375a448af6ee7e9302745d Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 14:20:50 -0400 Subject: [PATCH 18/27] fix: condition --- nx_arangodb/convert.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 50c37581..d697ab97 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -286,8 +286,11 @@ def _from_networkx_arangodb( if G.is_multigraph(): raise NotImplementedError("Multigraphs not yet supported") - if G.coo_use_cache and all( - [G.src_indices, G.dst_indices, G.vertex_ids_to_index] + if ( + G.coo_use_cache + and G.src_indices is not None + and G.dst_indices is not None + and G.vertex_ids_to_index is not None ): src_indices = G.src_indices dst_indices = G.dst_indices From 083be060d3e93e01e5d05719ab25a9de06782642 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 15:18:36 -0400 Subject: [PATCH 19/27] update algorithms --- nx_arangodb/algorithms/community/__init__.py | 1 + nx_arangodb/algorithms/community/louvain.py | 105 +++++++++++++++++- .../algorithms/link_analysis/__init__.py | 1 + .../algorithms/link_analysis/pagerank_alg.py | 62 ++++++++++- 4 files changed, 164 insertions(+), 5 deletions(-) diff --git a/nx_arangodb/algorithms/community/__init__.py b/nx_arangodb/algorithms/community/__init__.py index e69de29b..5b43a3e4 100644 --- a/nx_arangodb/algorithms/community/__init__.py +++ b/nx_arangodb/algorithms/community/__init__.py @@ -0,0 +1 @@ +from .louvain import * diff --git a/nx_arangodb/algorithms/community/louvain.py b/nx_arangodb/algorithms/community/louvain.py index a593b848..7da284d9 100644 --- a/nx_arangodb/algorithms/community/louvain.py +++ b/nx_arangodb/algorithms/community/louvain.py @@ -1,3 +1,7 @@ +from collections import deque + +import networkx as nx + from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph from nx_arangodb.utils import _dtype_param, networkx_algorithm @@ -22,7 +26,12 @@ name="louvain_communities", ) def louvain_communities( - G, weight="weight", resolution=1, threshold=0.0000001, max_level=None, seed=None + G, + weight="weight", + resolution=1, + threshold=0.0000001, + max_level=None, + seed=None, ): if GPU_ENABLED: print("ANTHONY: to_nxcg") @@ -39,6 +48,96 @@ def louvain_communities( ) else: - raise NotImplementedError( - "Louvain community detection is not supported on CPU for nxadb" + print("ANTHONY: to_nxadb") + G = _to_nxadb_graph(G) + + print("ANTHONY: Using nx pagerank()") + import random + + d = louvain_partitions(G, weight, resolution, threshold, random.Random()) + q = deque(d, maxlen=1) + return q.pop() + + +@networkx_algorithm( + extra_params={ + **_dtype_param, + }, + is_incomplete=True, # seed not supported; self-loops not supported + is_different=True, # RNG different + version_added="23.10", + _plc="louvain", + name="louvain_partitions", +) +def louvain_partitions( + G, weight="weight", resolution=1, threshold=0.0000001, seed=None +): + partition = [{u} for u in G.nodes()] + if nx.is_empty(G): + yield partition + return + mod = modularity(G, partition, resolution=resolution, weight=weight) + is_directed = G.is_directed() + if G.is_multigraph(): + graph = nx.community._convert_multigraph(G, weight, is_directed) + else: + graph = G.__class__() + graph.add_nodes_from(G) + graph.add_weighted_edges_from(G.edges(data=weight, default=1)) + + m = graph.size(weight="weight") + partition, inner_partition, improvement = nx.community.louvain._one_level( + graph, m, partition, resolution, is_directed, seed + ) + improvement = True + while improvement: + # gh-5901 protect the sets in the yielded list from further manipulation here + yield [s.copy() for s in partition] + new_mod = modularity( + graph, inner_partition, resolution=resolution, weight="weight" ) + if new_mod - mod <= threshold: + return + mod = new_mod + graph = nx.community.louvain._gen_graph(graph, inner_partition) + partition, inner_partition, improvement = nx.community.louvain._one_level( + graph, m, partition, resolution, is_directed, seed + ) + + +@networkx_algorithm( + extra_params={ + **_dtype_param, + }, + is_incomplete=True, # seed not supported; self-loops not supported + is_different=True, # RNG different + version_added="23.10", +) +def modularity(G, communities, weight="weight", resolution=1): + if not isinstance(communities, list): + communities = list(communities) + # if not is_partition(G, communities): + # raise NotAPartition(G, communities) + + directed = G.is_directed() + if directed: + out_degree = dict(G.out_degree(weight=weight)) + in_degree = dict(G.in_degree(weight=weight)) + m = sum(out_degree.values()) + norm = 1 / m**2 + else: + out_degree = in_degree = dict(G.degree(weight=weight)) + deg_sum = sum(out_degree.values()) + m = deg_sum / 2 + norm = 1 / deg_sum**2 + + def community_contribution(community): + comm = set(community) + L_c = sum(wt for u, v, wt in G.edges(comm, data=weight, default=1) if v in comm) + + out_degree_sum = sum(out_degree[u] for u in comm) + in_degree_sum = sum(in_degree[u] for u in comm) if directed else out_degree_sum + + return L_c / m - resolution * out_degree_sum * in_degree_sum * norm + + return sum(map(community_contribution, communities)) diff --git a/nx_arangodb/algorithms/link_analysis/__init__.py b/nx_arangodb/algorithms/link_analysis/__init__.py index e69de29b..7e957e4f 100644 --- a/nx_arangodb/algorithms/link_analysis/__init__.py +++ b/nx_arangodb/algorithms/link_analysis/__init__.py @@ -0,0 +1 @@ +from .pagerank_alg import * diff --git a/nx_arangodb/algorithms/link_analysis/pagerank_alg.py b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py index 6ca2ffbf..d8f0212c 100644 --- a/nx_arangodb/algorithms/link_analysis/pagerank_alg.py +++ b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py @@ -1,4 +1,4 @@ -from networkx.algorithms.link_analysis.pagerank_alg import _pagerank_scipy +import networkx as nx from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph from nx_arangodb.utils import _dtype_param, networkx_algorithm @@ -58,7 +58,7 @@ def pagerank( G = _to_nxadb_graph(G) print("ANTHONY: Using nx pagerank()") - return _pagerank_scipy( + return nx.algorithms.link_analysis.pagerank_alg._pagerank_scipy( G, alpha=alpha, personalization=personalization, @@ -68,3 +68,61 @@ def pagerank( weight=weight, dangling=dangling, ) + + +@networkx_algorithm( + extra_params=_dtype_param, + version_added="23.12", +) +def to_scipy_sparse_array(G, nodelist=None, dtype=None, weight="weight", format="csr"): + import scipy as sp + + if len(G) == 0: + raise nx.NetworkXError("Graph has no nodes or edges") + + if nodelist is None: + nodelist = list(G) + nlen = len(G) + else: + nlen = len(nodelist) + if nlen == 0: + raise nx.NetworkXError("nodelist has no nodes") + nodeset = set(G.nbunch_iter(nodelist)) + if nlen != len(nodeset): + for n in nodelist: + if n not in G: + raise nx.NetworkXError(f"Node {n} in nodelist is not in G") + raise nx.NetworkXError("nodelist contains duplicates.") + if nlen < len(G): + G = G.subgraph(nodelist) + + index = dict(zip(nodelist, range(nlen))) + coefficients = zip( + *((index[u], index[v], wt) for u, v, wt in G.edges(data=weight, default=1)) + ) + try: + row, col, data = coefficients + except ValueError: + # there is no edge in the subgraph + row, col, data = [], [], [] + + if G.is_directed(): + A = sp.sparse.coo_array((data, (row, col)), shape=(nlen, nlen), dtype=dtype) + else: + # symmetrize matrix + d = data + data + r = row + col + c = col + row + # selfloop entries get double counted when symmetrizing + # so we subtract the data on the diagonal + selfloops = list(nx.selfloop_edges(G, data=weight, default=1)) + if selfloops: + diag_index, diag_data = zip(*((index[u], -wt) for u, v, wt in selfloops)) + d += diag_data + r += diag_index + c += diag_index + A = sp.sparse.coo_array((d, (r, c)), shape=(nlen, nlen), dtype=dtype) + try: + return A.asformat(format) + except ValueError as err: + raise nx.NetworkXError(f"Unknown sparse matrix format: {format}") from err From e60c36276c4c5c2f73255b6e15ad104a41a11921 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 15:18:53 -0400 Subject: [PATCH 20/27] cleanup --- _nx_arangodb/__init__.py | 24 +++++++++++++++++++++--- nx_arangodb/convert.py | 1 + nx_arangodb/utils/misc.py | 1 + tests/test.py | 35 ++++++++++++++++++++++++++++++----- 4 files changed, 53 insertions(+), 8 deletions(-) diff --git a/_nx_arangodb/__init__.py b/_nx_arangodb/__init__.py index 91366e48..0bf0e9b9 100644 --- a/_nx_arangodb/__init__.py +++ b/_nx_arangodb/__init__.py @@ -31,6 +31,11 @@ "functions": { # BEGIN: functions "betweenness_centrality", + "louvain_communities", + "louvain_partitions", + "modularity", + "pagerank", + "to_scipy_sparse_array", # END: functions }, "additional_docs": { @@ -40,7 +45,21 @@ }, "additional_parameters": { # BEGIN: additional_parameters - + "louvain_communities": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, + "louvain_partitions": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, + "modularity": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, + "pagerank": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, + "to_scipy_sparse_array": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, # END: additional_parameters }, } @@ -91,8 +110,7 @@ def __call__(self, *args, **kwargs): sys.modules["cupy"] = Stub() sys.modules["numpy"] = Stub() - # sys.modules["pylibcugraph"] = Stub() # TODO Anthony: re-introduce when ready - sys.modules["python-arango"] = Stub() # TODO Anthony: Double check + sys.modules["python-arango"] = Stub() from _nx_arangodb.core import main diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index d697ab97..a15cb098 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -292,6 +292,7 @@ def _from_networkx_arangodb( and G.dst_indices is not None and G.vertex_ids_to_index is not None ): + print("ANTHONY: Using cached COO") src_indices = G.src_indices dst_indices = G.dst_indices vertex_ids_to_index = G.vertex_ids_to_index diff --git a/nx_arangodb/utils/misc.py b/nx_arangodb/utils/misc.py index 105228fd..55e339d3 100644 --- a/nx_arangodb/utils/misc.py +++ b/nx_arangodb/utils/misc.py @@ -31,6 +31,7 @@ def pairwise(it): "index_dtype", "_seed_to_int", "_get_int_dtype", + "_dtype_param", ] # This may switch to np.uint32 at some point diff --git a/tests/test.py b/tests/test.py index cc7692f3..f21b176b 100644 --- a/tests/test.py +++ b/tests/test.py @@ -14,9 +14,34 @@ def test_bc(): G_2 = nxadb.Graph(G_1) - bc_1 = nx.betweenness_centrality(G_1) - bc_2 = nx.betweenness_centrality(G_2) - bc_3 = nx.betweenness_centrality(G_1, backend="arangodb") - bc_4 = nx.betweenness_centrality(G_2, backend="arangodb") + r_1 = nx.betweenness_centrality(G_1) + r_2 = nx.betweenness_centrality(G_2) + r_3 = nx.betweenness_centrality(G_1, backend="arangodb") + r_4 = nx.betweenness_centrality(G_2, backend="arangodb") - assert bc_1 == bc_2 == bc_3 == bc_4 \ No newline at end of file + assert r_1 and r_2 and r_3 and r_4 + +def test_pagerank(): + G_1 = nx.karate_club_graph() + + G_2 = nxadb.Graph(G_1) + + r_1 = nx.pagerank(G_1) + r_2 = nx.pagerank(G_2) + r_3 = nx.pagerank(G_1, backend="arangodb") + r_4 = nx.pagerank(G_2, backend="arangodb") + + assert r_1 and r_2 and r_3 and r_4 + + +def test_louvain(): + G_1 = nx.karate_club_graph() + + G_2 = nxadb.Graph(G_1) + + r_1 = nx.community.louvain_communities(G_1) + r_2 = nx.community.louvain_communities(G_2) + r_3 = nx.community.louvain_communities(G_1, backend="arangodb") + r_4 = nx.community.louvain_communities(G_2, backend="arangodb") + + assert r_1 and r_2 and r_3 and r_4 From 48ca5155bbeb333153b8f36d5245ff36aefcabcc Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 15:32:16 -0400 Subject: [PATCH 21/27] fix: bad import --- nx_arangodb/algorithms/community/louvain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nx_arangodb/algorithms/community/louvain.py b/nx_arangodb/algorithms/community/louvain.py index 7da284d9..5b65c988 100644 --- a/nx_arangodb/algorithms/community/louvain.py +++ b/nx_arangodb/algorithms/community/louvain.py @@ -38,7 +38,7 @@ def louvain_communities( G = _to_nxcg_graph(G, weight) print("ANTHONY: Using nxcg louvain()") - return nxcg._louvain_communities( + return nxcg.algorithms.community.louvain._louvain_communities( G, weight=weight, resolution=resolution, From 4b144317b0e75d62aa73eeefe7b51da03ea84a2d Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 22:06:24 -0400 Subject: [PATCH 22/27] cleanup: convert --- nx_arangodb/convert.py | 134 ++++++++--------------------------------- 1 file changed, 26 insertions(+), 108 deletions(-) diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index a15cb098..12f46c11 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -160,68 +160,34 @@ def to_networkx(G: nxadb.Graph, *, sort_edges: bool = False) -> nx.Graph: return G.to_networkx_class()(incoming_graph_data=G) -def _to_nxadb_graph( - G, - edge_attr: AttrKey | None = None, - edge_default: EdgeValue | None = 1, - edge_dtype: Dtype | None = None, -) -> nxadb.Graph | nxadb.DiGraph: - """Ensure that input type is a nx_arangodb graph, and convert if necessary. +def from_networkx_arangodb(G: nxadb.Graph) -> nxadb.Graph: + if not G.graph_exists: + print("ANTHONY: Graph does not exist, nothing to pull") + return G - Directed and undirected graphs are both allowed. - This is an internal utility function and may change or be removed. - """ - if isinstance(G, nxadb.Graph): + if G.use_node_and_adj_dict_cache and len(G.nodes) > 0 and len(G.adj) > 0: + print("ANTHONY: Using cached node and adj dict") return G - if isinstance(G, nx.Graph): - return from_networkx( - G, {edge_attr: edge_default} if edge_attr is not None else None, edge_dtype - ) - # TODO: handle cugraph.Graph - raise TypeError + start_time = time.time() + G.pull(load_coo=False) + end_time = time.time() -def _to_nxadb_directed_graph( - G, - edge_attr: AttrKey | None = None, - edge_default: EdgeValue | None = 1, - edge_dtype: Dtype | None = None, -) -> nxadb.DiGraph: - """Ensure that input type is a nx_arangodb DiGraph, and convert if necessary. + print("ANTHONY: Node & Adj Load took:", end_time - start_time) - Undirected graphs will be converted to directed. - This is an internal utility function and may change or be removed. - """ - if isinstance(G, nxadb.DiGraph): - return G - if isinstance(G, nxadb.Graph): - return G.to_directed() - if isinstance(G, nx.Graph): - return from_networkx( - G, - {edge_attr: edge_default} if edge_attr is not None else None, - edge_dtype, - as_directed=True, - ) - # TODO: handle cugraph.Graph - raise TypeError + return G -def _to_nxadb_undirected_graph( +def _to_nxadb_graph( G, edge_attr: AttrKey | None = None, edge_default: EdgeValue | None = 1, edge_dtype: Dtype | None = None, -) -> nxadb.Graph: - """Ensure that input type is a nx_arangodb Graph, and convert if necessary. - - Only undirected graphs are allowed. Directed graphs will raise ValueError. - This is an internal utility function and may change or be removed. - """ +) -> nxadb.Graph | nxadb.DiGraph: + """Ensure that input type is a nx_arangodb graph, and convert if necessary.""" if isinstance(G, nxadb.Graph): - if G.is_directed(): - raise ValueError("Only undirected graphs supported; got a directed graph") - return G + return from_networkx_arangodb(G) + if isinstance(G, nx.Graph): return from_networkx( G, {edge_attr: edge_default} if edge_attr is not None else None, edge_dtype @@ -237,7 +203,6 @@ def _to_nxadb_undirected_graph( import cupy as cp import numpy as np import nx_cugraph as nxcg - from phenolrs.coo_loader import CooLoader def _to_nxcg_graph( G, @@ -246,11 +211,7 @@ def _to_nxcg_graph( edge_dtype: Dtype | None = None, as_directed: bool = False, ) -> nxcg.Graph | nxcg.DiGraph: - """Ensure that input type is a nx_cugraph graph, and convert if necessary. - - Directed and undirected graphs are both allowed. - This is an internal utility function and may change or be removed. - """ + """Ensure that input type is a nx_cugraph graph, and convert if necessary.""" if isinstance(G, nxcg.Graph): return G if isinstance(G, nxadb.Graph): @@ -262,8 +223,8 @@ def _to_nxcg_graph( # the NetworkX graph to an nx_cugraph graph. # TODO: Implement a direct conversion from ArangoDB to nx_cugraph if G.graph_exists: - print("ANTHONY: Graph exists! Running _from_networkx_arangodb()") - return _from_networkx_arangodb(G, as_directed=as_directed) + print("ANTHONY: Graph exists, running _nxadb_to_nxcg()") + return _nxadb_to_nxcg(G, as_directed=as_directed) # If G is a networkx graph, or is a nxadb graph that doesn't point to an "existing" # ArangoDB graph, then we just treat it as a normal networkx graph & @@ -280,71 +241,28 @@ def _to_nxcg_graph( # TODO: handle cugraph.Graph raise TypeError - def _from_networkx_arangodb( + def _nxadb_to_nxcg( G: nxadb.Graph, as_directed: bool = False ) -> nxcg.Graph | nxcg.DiGraph: if G.is_multigraph(): raise NotImplementedError("Multigraphs not yet supported") if ( - G.coo_use_cache + G.use_coo_cache and G.src_indices is not None and G.dst_indices is not None and G.vertex_ids_to_index is not None ): print("ANTHONY: Using cached COO") - src_indices = G.src_indices - dst_indices = G.dst_indices - vertex_ids_to_index = G.vertex_ids_to_index else: - adb_graph = G.db.graph(G.graph_name) - - v_cols = adb_graph.vertex_collections() - edge_definitions = adb_graph.edge_definitions() - e_cols = {c["edge_collection"] for c in edge_definitions} - - metagraph = { - "vertexCollections": {col: {} for col in v_cols}, - "edgeCollections": {col: {} for col in e_cols}, - } - start_time = time.time() - - kwargs = {} - if G.coo_load_parallelism is not None: - kwargs["parallelism"] = G.coo_load_parallelism - if G.coo_load_batch_size is not None: - kwargs["batch_size"] = G.coo_load_batch_size - - src_indices, dst_indices, vertex_ids_to_index = CooLoader.load_coo( - G.db.name, - metagraph, - [os.environ["DATABASE_HOST"]], - username=os.environ["DATABASE_USERNAME"], - password=os.environ["DATABASE_PASSWORD"], - **kwargs, - ) - + G.pull(load_node_and_adj_dict=False) end_time = time.time() print("ANTHONY: COO Load took:", end_time - start_time) - start_time = time.time() - - src_indices = cp.array(src_indices) - dst_indices = cp.array(dst_indices) - - end_time = time.time() - - print("ANTHONY: cupy arrays took:", end_time - start_time) - - if G.coo_use_cache: - G.src_indices = src_indices - G.dst_indices = dst_indices - G.vertex_ids_to_index = vertex_ids_to_index - - N = len(vertex_ids_to_index) + N = len(G.vertex_ids_to_index) if G.is_directed() or as_directed: klass = nxcg.DiGraph @@ -355,9 +273,9 @@ def _from_networkx_arangodb( rv = klass.from_coo( N, - src_indices, - dst_indices, - key_to_id=vertex_ids_to_index, + cp.array(G.src_indices), + cp.array(G.dst_indices), + key_to_id=G.vertex_ids_to_index, ) end_time = time.time() From 742633b7f91a36412aa3a021caadaab8ac91225b Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 22:06:35 -0400 Subject: [PATCH 23/27] new: Graph `pull` method --- nx_arangodb/classes/graph.py | 73 +++++++++++++++++++++++++++++++----- 1 file changed, 63 insertions(+), 10 deletions(-) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 4a8f916e..26289fc7 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -31,9 +31,12 @@ def __init__( self.__graph_name = None self.__graph_exists = False - self.coo_use_cache = False - self.coo_load_parallelism = None - self.coo_load_batch_size = None + self.graph_loader_parallelism = None + self.graph_loader_batch_size = None + + self.use_node_and_adj_dict_cache = False + self.use_coo_cache = False + self.src_indices = None self.dst_indices = None self.vertex_ids_to_index = None @@ -75,19 +78,19 @@ def set_db(self, db: StandardDatabase | None = None): self.__db = db return - host = os.getenv("DATABASE_HOST") - username = os.getenv("DATABASE_USERNAME") - password = os.getenv("DATABASE_PASSWORD") - db_name = os.getenv("DATABASE_NAME") + self.__host = os.getenv("DATABASE_HOST") + self.__username = os.getenv("DATABASE_USERNAME") + self.__password = os.getenv("DATABASE_PASSWORD") + self.__db_name = os.getenv("DATABASE_NAME") # TODO: Raise a custom exception if any of the environment # variables are missing. For now, we'll just set db to None. - if not all([host, username, password, db_name]): + if not all([self.__host, self.__username, self.__password, self.__db_name]): self.__db = None return - self.__db = ArangoClient(hosts=host, request_timeout=None).db( - db_name, username, password, verify=True + self.__db = ArangoClient(hosts=self.__host, request_timeout=None).db( + self.__db_name, self.__username, self.__password, verify=True ) def set_graph_name(self, graph_name: str | None = None): @@ -112,3 +115,53 @@ def set_graph_name(self, graph_name: str | None = None): else: self.__graph_exists = True print(f"Found graph '{self.__graph_name}' in the database") + + def pull(self, load_node_and_adj_dict=True, load_coo=True): + if not self.graph_exists: + raise ValueError("Graph does not exist in the database") + + adb_graph = self.db.graph(self.graph_name) + + v_cols = adb_graph.vertex_collections() + edge_definitions = adb_graph.edge_definitions() + e_cols = {c["edge_collection"] for c in edge_definitions} + + metagraph = { + "vertexCollections": {col: {} for col in v_cols}, + "edgeCollections": {col: {} for col in e_cols}, + } + + from phenolrs.graph_loader import GraphLoader + + kwargs = {} + if self.graph_loader_parallelism is not None: + kwargs["parallelism"] = self.graph_loader_parallelism + if self.graph_loader_batch_size is not None: + kwargs["batch_size"] = self.graph_loader_batch_size + + result = GraphLoader.load( + self.db.name, + metagraph, + [os.environ["DATABASE_HOST"]], + username=os.environ["DATABASE_USERNAME"], + password=os.environ["DATABASE_PASSWORD"], + load_node_dict=load_node_and_adj_dict, + load_adj_dict=load_node_and_adj_dict, + load_coo=load_coo, + **kwargs, + ) + + if load_node_and_adj_dict: + # hacky, i don't like this + # need to revisit... + # consider using nx.convert.from_dict_of_dicts instead + self._node = result[0] + self._adj = result[1] + + if load_coo: + self.src_indices = result[2] + self.dst_indices = result[3] + self.vertex_ids_to_index = result[4] + + def push(self): + raise NotImplementedError("What would this look like?") From 2519512dea5837bb347f77b57d1d8bad7a2189a5 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 22:06:39 -0400 Subject: [PATCH 24/27] update `digraph` --- nx_arangodb/classes/digraph.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 9c8cd1e5..9c8b3ab6 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -1,4 +1,5 @@ import networkx as nx +from arango.database import StandardDatabase import nx_arangodb as nxadb from nx_arangodb.classes.graph import Graph @@ -24,9 +25,12 @@ def __init__( self.__graph_name = None self.__graph_exists = False - self.coo_use_cache = False - self.coo_load_parallelism = None - self.coo_load_batch_size = None + self.graph_loader_parallelism = None + self.graph_loader_batch_size = None + + self.use_node_and_adj_dict_cache = False + self.use_coo_cache = False + self.src_indices = None self.dst_indices = None self.vertex_ids_to_index = None @@ -35,6 +39,20 @@ def __init__( if self.__db is not None: self.set_graph_name() + @property + def db(self) -> StandardDatabase: + if self.__db is None: + raise ValueError("Database not set") + + return self.__db + + @property + def graph_name(self) -> str: + if self.__graph_name is None: + raise ValueError("Graph name not set") + + return self.__graph_name + @property def graph_exists(self) -> bool: return self.__graph_exists From 2b6dc347f297580a7cdec560994ada7357987bed Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 22:21:31 -0400 Subject: [PATCH 25/27] fix: missing param --- nx_arangodb/algorithms/community/louvain.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nx_arangodb/algorithms/community/louvain.py b/nx_arangodb/algorithms/community/louvain.py index 5b65c988..3aa77857 100644 --- a/nx_arangodb/algorithms/community/louvain.py +++ b/nx_arangodb/algorithms/community/louvain.py @@ -32,8 +32,9 @@ def louvain_communities( threshold=0.0000001, max_level=None, seed=None, + run_on_gpu=True, ): - if GPU_ENABLED: + if GPU_ENABLED and run_on_gpu: print("ANTHONY: to_nxcg") G = _to_nxcg_graph(G, weight) From 5bafb9d9a112b85e365ec8e2d19d4a76da2bd1d9 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 23:07:38 -0400 Subject: [PATCH 26/27] copy methods to digraph temporary workaround... --- nx_arangodb/classes/digraph.py | 125 +++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 9c8b3ab6..838f2dc9 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -1,4 +1,7 @@ +import os + import networkx as nx +from arango import ArangoClient from arango.database import StandardDatabase import nx_arangodb as nxadb @@ -56,3 +59,125 @@ def graph_name(self) -> str: @property def graph_exists(self) -> bool: return self.__graph_exists + + @property + def db(self) -> StandardDatabase: + if self.__db is None: + raise ValueError("Database not set") + + return self.__db + + @property + def graph_name(self) -> str: + if self.__graph_name is None: + raise ValueError("Graph name not set") + + return self.__graph_name + + @property + def graph_exists(self) -> bool: + return self.__graph_exists + + def clear_coo_cache(self): + self.src_indices = None + self.dst_indices = None + self.vertex_ids_to_index = None + + def set_db(self, db: StandardDatabase | None = None): + if db is not None: + if not isinstance(db, StandardDatabase): + raise TypeError( + "**db** must be an instance of arango.database.StandardDatabase" + ) + + self.__db = db + return + + self.__host = os.getenv("DATABASE_HOST") + self.__username = os.getenv("DATABASE_USERNAME") + self.__password = os.getenv("DATABASE_PASSWORD") + self.__db_name = os.getenv("DATABASE_NAME") + + # TODO: Raise a custom exception if any of the environment + # variables are missing. For now, we'll just set db to None. + if not all([self.__host, self.__username, self.__password, self.__db_name]): + self.__db = None + return + + self.__db = ArangoClient(hosts=self.__host, request_timeout=None).db( + self.__db_name, self.__username, self.__password, verify=True + ) + + def set_graph_name(self, graph_name: str | None = None): + if self.__db is None: + raise ValueError("Cannot set graph name without setting the database first") + + self.__graph_name = os.getenv("DATABASE_GRAPH_NAME") + if graph_name is not None: + if not isinstance(graph_name, str): + raise TypeError("**graph_name** must be a string") + + self.__graph_name = graph_name + + if self.__graph_name is None: + self.__graph_exists = False + print("DATABASE_GRAPH_NAME environment variable not set") + + elif not self.db.has_graph(self.__graph_name): + self.__graph_exists = False + print(f"Graph '{self.__graph_name}' does not exist in the database") + + else: + self.__graph_exists = True + print(f"Found graph '{self.__graph_name}' in the database") + + def pull(self, load_node_and_adj_dict=True, load_coo=True): + if not self.graph_exists: + raise ValueError("Graph does not exist in the database") + + adb_graph = self.db.graph(self.graph_name) + + v_cols = adb_graph.vertex_collections() + edge_definitions = adb_graph.edge_definitions() + e_cols = {c["edge_collection"] for c in edge_definitions} + + metagraph = { + "vertexCollections": {col: {} for col in v_cols}, + "edgeCollections": {col: {} for col in e_cols}, + } + + from phenolrs.graph_loader import GraphLoader + + kwargs = {} + if self.graph_loader_parallelism is not None: + kwargs["parallelism"] = self.graph_loader_parallelism + if self.graph_loader_batch_size is not None: + kwargs["batch_size"] = self.graph_loader_batch_size + + result = GraphLoader.load( + self.db.name, + metagraph, + [self.__host], + username=self.__username, + password=self.__password, + load_node_dict=load_node_and_adj_dict, + load_adj_dict=load_node_and_adj_dict, + load_adj_dict_as_undirected=False, + load_coo=load_coo, + **kwargs, + ) + + if load_node_and_adj_dict: + # hacky, i don't like this + # need to revisit... + # consider using nx.convert.from_dict_of_dicts instead + self._node = result[0] + self._adj = result[1] + + if load_coo: + self.src_indices = result[2] + self.dst_indices = result[3] + self.vertex_ids_to_index = result[4] + + def push(self): + raise NotImplementedError("What would this look like?") From 334b960b852fcfa6ea8433ae6af362f9f74182b3 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 1 May 2024 23:07:47 -0400 Subject: [PATCH 27/27] new: `load_adj_dict_as_undirected` --- nx_arangodb/classes/graph.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 26289fc7..71ea891d 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -142,11 +142,12 @@ def pull(self, load_node_and_adj_dict=True, load_coo=True): result = GraphLoader.load( self.db.name, metagraph, - [os.environ["DATABASE_HOST"]], - username=os.environ["DATABASE_USERNAME"], - password=os.environ["DATABASE_PASSWORD"], + [self.__host], + username=self.__username, + password=self.__password, load_node_dict=load_node_and_adj_dict, load_adj_dict=load_node_and_adj_dict, + load_adj_dict_as_undirected=True, load_coo=load_coo, **kwargs, )