diff --git a/_nx_arangodb/__init__.py b/_nx_arangodb/__init__.py index 91366e48..0bf0e9b9 100644 --- a/_nx_arangodb/__init__.py +++ b/_nx_arangodb/__init__.py @@ -31,6 +31,11 @@ "functions": { # BEGIN: functions "betweenness_centrality", + "louvain_communities", + "louvain_partitions", + "modularity", + "pagerank", + "to_scipy_sparse_array", # END: functions }, "additional_docs": { @@ -40,7 +45,21 @@ }, "additional_parameters": { # BEGIN: additional_parameters - + "louvain_communities": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, + "louvain_partitions": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, + "modularity": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, + "pagerank": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, + "to_scipy_sparse_array": { + "dtype : dtype or None, optional": "The data type (np.float32, np.float64, or None) to use for the edge weights in the algorithm. If None, then dtype is determined by the edge values.", + }, # END: additional_parameters }, } @@ -91,8 +110,7 @@ def __call__(self, *args, **kwargs): sys.modules["cupy"] = Stub() sys.modules["numpy"] = Stub() - # sys.modules["pylibcugraph"] = Stub() # TODO Anthony: re-introduce when ready - sys.modules["python-arango"] = Stub() # TODO Anthony: Double check + sys.modules["python-arango"] = Stub() from _nx_arangodb.core import main diff --git a/nx_arangodb/__init__.py b/nx_arangodb/__init__.py index 4400a9f8..fcb0b7ac 100644 --- a/nx_arangodb/__init__.py +++ b/nx_arangodb/__init__.py @@ -10,10 +10,6 @@ from . import convert from .convert import * -# TODO Anthony: Do we need this? -# from . import convert_matrix -# from .convert_matrix import * - from . import algorithms from .algorithms import * diff --git a/nx_arangodb/algorithms/__init__.py b/nx_arangodb/algorithms/__init__.py index 570d5dd6..60bb5633 100644 --- a/nx_arangodb/algorithms/__init__.py +++ b/nx_arangodb/algorithms/__init__.py @@ -1,2 +1,4 @@ -from . import centrality +from . import centrality, community, link_analysis from .centrality import * +from .community import * +from .link_analysis import * diff --git a/nx_arangodb/algorithms/centrality/betweenness.py b/nx_arangodb/algorithms/centrality/betweenness.py index c8bd0f1b..2ec752b4 100644 --- a/nx_arangodb/algorithms/centrality/betweenness.py +++ b/nx_arangodb/algorithms/centrality/betweenness.py @@ -1,23 +1,23 @@ from networkx.algorithms.centrality import betweenness as nx_betweenness -from nx_arangodb.convert import _to_graph as _to_nx_arangodb_graph +from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph from nx_arangodb.utils import networkx_algorithm try: - import pylibcugraph as plc - from nx_cugraph.convert import _to_graph as _to_nx_cugraph_graph - from nx_cugraph.utils import _seed_to_int + import nx_cugraph as nxcg GPU_ENABLED = True + print("ANTHONY: GPU is enabled") except ModuleNotFoundError: GPU_ENABLED = False + print("ANTHONY: GPU is disabled") __all__ = ["betweenness_centrality"] -# 1. If GPU is enabled, call nx-cugraph bc() after converting to a nx_cugraph graph (in-memory graph) -# 2. If GPU is not enabled, call networkx bc() after converting to a networkx graph (in-memory graph) -# 3. If GPU is not enabled, call networkx bc() **without** converting to a networkx graph (remote graph) +# 1. If GPU is enabled, call nx-cugraph bc() after converting to an ncxg graph (in-memory graph) +# 2. If GPU is not enabled, call networkx bc() after converting to an nxadb graph (in-memory graph) +# 3. If GPU is not enabled, call networkx bc() **without** converting to a nxadb graph (remote graph) @networkx_algorithm( @@ -33,32 +33,19 @@ def betweenness_centrality( # 1. if GPU_ENABLED and run_on_gpu: - print("ANTHONY: GPU is enabled. Using nx-cugraph bc()") - - if weight is not None: - raise NotImplementedError( - "Weighted implementation of betweenness centrality not currently supported" - ) - - seed = _seed_to_int(seed) - G = _to_nx_cugraph_graph(G, weight) - node_ids, values = plc.betweenness_centrality( - resource_handle=plc.ResourceHandle(), - graph=G._get_plc_graph(), - k=k, - random_state=seed, - normalized=normalized, - include_endpoints=endpoints, - do_expensive_check=False, - ) + print("ANTHONY: to_nxcg") + G = _to_nxcg_graph(G, weight) - return G._nodearrays_to_dict(node_ids, values) + print("ANTHONY: Using nxcg bc()") + return nxcg.betweenness_centrality(G, k=k, normalized=normalized, weight=weight) # 2. else: - print("ANTHONY: GPU is disabled. Using nx bc()") - G = _to_nx_arangodb_graph(G) + print("ANTHONY: to_nxadb") + G = _to_nxadb_graph(G) + + print("ANTHONY: Using nx bc()") betweenness = dict.fromkeys(G, 0.0) # b[v]=0 for v in G if k is None: @@ -93,5 +80,3 @@ def betweenness_centrality( ) return betweenness - - # 3. TODO diff --git a/nx_arangodb/algorithms/community/__init__.py b/nx_arangodb/algorithms/community/__init__.py new file mode 100644 index 00000000..5b43a3e4 --- /dev/null +++ b/nx_arangodb/algorithms/community/__init__.py @@ -0,0 +1 @@ +from .louvain import * diff --git a/nx_arangodb/algorithms/community/louvain.py b/nx_arangodb/algorithms/community/louvain.py new file mode 100644 index 00000000..3aa77857 --- /dev/null +++ b/nx_arangodb/algorithms/community/louvain.py @@ -0,0 +1,144 @@ +from collections import deque + +import networkx as nx + +from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph +from nx_arangodb.utils import _dtype_param, networkx_algorithm + +try: + import nx_cugraph as nxcg + + GPU_ENABLED = True + print("ANTHONY: GPU is enabled") +except ModuleNotFoundError: + GPU_ENABLED = False + print("ANTHONY: GPU is disabled") + + +@networkx_algorithm( + extra_params={ + **_dtype_param, + }, + is_incomplete=True, # seed not supported; self-loops not supported + is_different=True, # RNG different + version_added="23.10", + _plc="louvain", + name="louvain_communities", +) +def louvain_communities( + G, + weight="weight", + resolution=1, + threshold=0.0000001, + max_level=None, + seed=None, + run_on_gpu=True, +): + if GPU_ENABLED and run_on_gpu: + print("ANTHONY: to_nxcg") + G = _to_nxcg_graph(G, weight) + + print("ANTHONY: Using nxcg louvain()") + return nxcg.algorithms.community.louvain._louvain_communities( + G, + weight=weight, + resolution=resolution, + threshold=threshold, + max_level=max_level, + seed=seed, + ) + + else: + print("ANTHONY: to_nxadb") + G = _to_nxadb_graph(G) + + print("ANTHONY: Using nx pagerank()") + import random + + d = louvain_partitions(G, weight, resolution, threshold, random.Random()) + q = deque(d, maxlen=1) + return q.pop() + + +@networkx_algorithm( + extra_params={ + **_dtype_param, + }, + is_incomplete=True, # seed not supported; self-loops not supported + is_different=True, # RNG different + version_added="23.10", + _plc="louvain", + name="louvain_partitions", +) +def louvain_partitions( + G, weight="weight", resolution=1, threshold=0.0000001, seed=None +): + partition = [{u} for u in G.nodes()] + if nx.is_empty(G): + yield partition + return + mod = modularity(G, partition, resolution=resolution, weight=weight) + is_directed = G.is_directed() + if G.is_multigraph(): + graph = nx.community._convert_multigraph(G, weight, is_directed) + else: + graph = G.__class__() + graph.add_nodes_from(G) + graph.add_weighted_edges_from(G.edges(data=weight, default=1)) + + m = graph.size(weight="weight") + partition, inner_partition, improvement = nx.community.louvain._one_level( + graph, m, partition, resolution, is_directed, seed + ) + improvement = True + while improvement: + # gh-5901 protect the sets in the yielded list from further manipulation here + yield [s.copy() for s in partition] + new_mod = modularity( + graph, inner_partition, resolution=resolution, weight="weight" + ) + if new_mod - mod <= threshold: + return + mod = new_mod + graph = nx.community.louvain._gen_graph(graph, inner_partition) + partition, inner_partition, improvement = nx.community.louvain._one_level( + graph, m, partition, resolution, is_directed, seed + ) + + +@networkx_algorithm( + extra_params={ + **_dtype_param, + }, + is_incomplete=True, # seed not supported; self-loops not supported + is_different=True, # RNG different + version_added="23.10", +) +def modularity(G, communities, weight="weight", resolution=1): + if not isinstance(communities, list): + communities = list(communities) + # if not is_partition(G, communities): + # raise NotAPartition(G, communities) + + directed = G.is_directed() + if directed: + out_degree = dict(G.out_degree(weight=weight)) + in_degree = dict(G.in_degree(weight=weight)) + m = sum(out_degree.values()) + norm = 1 / m**2 + else: + out_degree = in_degree = dict(G.degree(weight=weight)) + deg_sum = sum(out_degree.values()) + m = deg_sum / 2 + norm = 1 / deg_sum**2 + + def community_contribution(community): + comm = set(community) + L_c = sum(wt for u, v, wt in G.edges(comm, data=weight, default=1) if v in comm) + + out_degree_sum = sum(out_degree[u] for u in comm) + in_degree_sum = sum(in_degree[u] for u in comm) if directed else out_degree_sum + + return L_c / m - resolution * out_degree_sum * in_degree_sum * norm + + return sum(map(community_contribution, communities)) diff --git a/nx_arangodb/algorithms/link_analysis/__init__.py b/nx_arangodb/algorithms/link_analysis/__init__.py new file mode 100644 index 00000000..7e957e4f --- /dev/null +++ b/nx_arangodb/algorithms/link_analysis/__init__.py @@ -0,0 +1 @@ +from .pagerank_alg import * diff --git a/nx_arangodb/algorithms/link_analysis/pagerank_alg.py b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py new file mode 100644 index 00000000..d8f0212c --- /dev/null +++ b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py @@ -0,0 +1,128 @@ +import networkx as nx + +from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph +from nx_arangodb.utils import _dtype_param, networkx_algorithm + +try: + import nx_cugraph as nxcg + + GPU_ENABLED = True + print("ANTHONY: GPU is enabled") +except ModuleNotFoundError: + GPU_ENABLED = False + print("ANTHONY: GPU is disabled") + + +@networkx_algorithm( + extra_params=_dtype_param, + is_incomplete=True, # dangling not supported + version_added="23.12", + _plc={"pagerank", "personalized_pagerank"}, +) +def pagerank( + G, + alpha=0.85, + personalization=None, + max_iter=100, + tol=1.0e-6, + nstart=None, + weight="weight", + dangling=None, + *, + dtype=None, + run_on_gpu=True, +): + print("ANTHONY: Calling pagerank from nx_arangodb") + + # 1. + if GPU_ENABLED and run_on_gpu: + print("ANTHONY: to_nxcg") + G = _to_nxcg_graph(G, weight) + + print("ANTHONY: Using nxcg pagerank()") + return nxcg.pagerank( + G, + alpha=alpha, + personalization=personalization, + max_iter=max_iter, + tol=tol, + nstart=nstart, + weight=weight, + dangling=dangling, + dtype=dtype, + ) + + # 2. + else: + print("ANTHONY: to_nxadb") + G = _to_nxadb_graph(G) + + print("ANTHONY: Using nx pagerank()") + return nx.algorithms.link_analysis.pagerank_alg._pagerank_scipy( + G, + alpha=alpha, + personalization=personalization, + max_iter=max_iter, + tol=tol, + nstart=nstart, + weight=weight, + dangling=dangling, + ) + + +@networkx_algorithm( + extra_params=_dtype_param, + version_added="23.12", +) +def to_scipy_sparse_array(G, nodelist=None, dtype=None, weight="weight", format="csr"): + import scipy as sp + + if len(G) == 0: + raise nx.NetworkXError("Graph has no nodes or edges") + + if nodelist is None: + nodelist = list(G) + nlen = len(G) + else: + nlen = len(nodelist) + if nlen == 0: + raise nx.NetworkXError("nodelist has no nodes") + nodeset = set(G.nbunch_iter(nodelist)) + if nlen != len(nodeset): + for n in nodelist: + if n not in G: + raise nx.NetworkXError(f"Node {n} in nodelist is not in G") + raise nx.NetworkXError("nodelist contains duplicates.") + if nlen < len(G): + G = G.subgraph(nodelist) + + index = dict(zip(nodelist, range(nlen))) + coefficients = zip( + *((index[u], index[v], wt) for u, v, wt in G.edges(data=weight, default=1)) + ) + try: + row, col, data = coefficients + except ValueError: + # there is no edge in the subgraph + row, col, data = [], [], [] + + if G.is_directed(): + A = sp.sparse.coo_array((data, (row, col)), shape=(nlen, nlen), dtype=dtype) + else: + # symmetrize matrix + d = data + data + r = row + col + c = col + row + # selfloop entries get double counted when symmetrizing + # so we subtract the data on the diagonal + selfloops = list(nx.selfloop_edges(G, data=weight, default=1)) + if selfloops: + diag_index, diag_data = zip(*((index[u], -wt) for u, v, wt in selfloops)) + d += diag_data + r += diag_index + c += diag_index + A = sp.sparse.coo_array((d, (r, c)), shape=(nlen, nlen), dtype=dtype) + try: + return A.asformat(format) + except ValueError as err: + raise nx.NetworkXError(f"Unknown sparse matrix format: {format}") from err diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index a3add415..838f2dc9 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -1,4 +1,8 @@ +import os + import networkx as nx +from arango import ArangoClient +from arango.database import StandardDatabase import nx_arangodb as nxadb from nx_arangodb.classes.graph import Graph @@ -12,3 +16,168 @@ class DiGraph(nx.DiGraph, Graph): @classmethod def to_networkx_class(cls) -> type[nx.DiGraph]: return nx.DiGraph + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.__db = None + self.__graph_name = None + self.__graph_exists = False + + self.graph_loader_parallelism = None + self.graph_loader_batch_size = None + + self.use_node_and_adj_dict_cache = False + self.use_coo_cache = False + + self.src_indices = None + self.dst_indices = None + self.vertex_ids_to_index = None + + self.set_db() + if self.__db is not None: + self.set_graph_name() + + @property + def db(self) -> StandardDatabase: + if self.__db is None: + raise ValueError("Database not set") + + return self.__db + + @property + def graph_name(self) -> str: + if self.__graph_name is None: + raise ValueError("Graph name not set") + + return self.__graph_name + + @property + def graph_exists(self) -> bool: + return self.__graph_exists + + @property + def db(self) -> StandardDatabase: + if self.__db is None: + raise ValueError("Database not set") + + return self.__db + + @property + def graph_name(self) -> str: + if self.__graph_name is None: + raise ValueError("Graph name not set") + + return self.__graph_name + + @property + def graph_exists(self) -> bool: + return self.__graph_exists + + def clear_coo_cache(self): + self.src_indices = None + self.dst_indices = None + self.vertex_ids_to_index = None + + def set_db(self, db: StandardDatabase | None = None): + if db is not None: + if not isinstance(db, StandardDatabase): + raise TypeError( + "**db** must be an instance of arango.database.StandardDatabase" + ) + + self.__db = db + return + + self.__host = os.getenv("DATABASE_HOST") + self.__username = os.getenv("DATABASE_USERNAME") + self.__password = os.getenv("DATABASE_PASSWORD") + self.__db_name = os.getenv("DATABASE_NAME") + + # TODO: Raise a custom exception if any of the environment + # variables are missing. For now, we'll just set db to None. + if not all([self.__host, self.__username, self.__password, self.__db_name]): + self.__db = None + return + + self.__db = ArangoClient(hosts=self.__host, request_timeout=None).db( + self.__db_name, self.__username, self.__password, verify=True + ) + + def set_graph_name(self, graph_name: str | None = None): + if self.__db is None: + raise ValueError("Cannot set graph name without setting the database first") + + self.__graph_name = os.getenv("DATABASE_GRAPH_NAME") + if graph_name is not None: + if not isinstance(graph_name, str): + raise TypeError("**graph_name** must be a string") + + self.__graph_name = graph_name + + if self.__graph_name is None: + self.__graph_exists = False + print("DATABASE_GRAPH_NAME environment variable not set") + + elif not self.db.has_graph(self.__graph_name): + self.__graph_exists = False + print(f"Graph '{self.__graph_name}' does not exist in the database") + + else: + self.__graph_exists = True + print(f"Found graph '{self.__graph_name}' in the database") + + def pull(self, load_node_and_adj_dict=True, load_coo=True): + if not self.graph_exists: + raise ValueError("Graph does not exist in the database") + + adb_graph = self.db.graph(self.graph_name) + + v_cols = adb_graph.vertex_collections() + edge_definitions = adb_graph.edge_definitions() + e_cols = {c["edge_collection"] for c in edge_definitions} + + metagraph = { + "vertexCollections": {col: {} for col in v_cols}, + "edgeCollections": {col: {} for col in e_cols}, + } + + from phenolrs.graph_loader import GraphLoader + + kwargs = {} + if self.graph_loader_parallelism is not None: + kwargs["parallelism"] = self.graph_loader_parallelism + if self.graph_loader_batch_size is not None: + kwargs["batch_size"] = self.graph_loader_batch_size + + result = GraphLoader.load( + self.db.name, + metagraph, + [self.__host], + username=self.__username, + password=self.__password, + load_node_dict=load_node_and_adj_dict, + load_adj_dict=load_node_and_adj_dict, + load_adj_dict_as_undirected=False, + load_coo=load_coo, + **kwargs, + ) + + if load_node_and_adj_dict: + # hacky, i don't like this + # need to revisit... + # consider using nx.convert.from_dict_of_dicts instead + self._node = result[0] + self._adj = result[1] + + if load_coo: + self.src_indices = result[2] + self.dst_indices = result[3] + self.vertex_ids_to_index = result[4] + + def push(self): + raise NotImplementedError("What would this look like?") diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index db495e36..71ea891d 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -1,6 +1,9 @@ +import os from typing import ClassVar import networkx as nx +from arango import ArangoClient +from arango.database import StandardDatabase import nx_arangodb as nxadb @@ -16,3 +19,150 @@ class Graph(nx.Graph): @classmethod def to_networkx_class(cls) -> type[nx.Graph]: return nx.Graph + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.__db = None + self.__graph_name = None + self.__graph_exists = False + + self.graph_loader_parallelism = None + self.graph_loader_batch_size = None + + self.use_node_and_adj_dict_cache = False + self.use_coo_cache = False + + self.src_indices = None + self.dst_indices = None + self.vertex_ids_to_index = None + + self.set_db() + if self.__db is not None: + self.set_graph_name() + + @property + def db(self) -> StandardDatabase: + if self.__db is None: + raise ValueError("Database not set") + + return self.__db + + @property + def graph_name(self) -> str: + if self.__graph_name is None: + raise ValueError("Graph name not set") + + return self.__graph_name + + @property + def graph_exists(self) -> bool: + return self.__graph_exists + + def clear_coo_cache(self): + self.src_indices = None + self.dst_indices = None + self.vertex_ids_to_index = None + + def set_db(self, db: StandardDatabase | None = None): + if db is not None: + if not isinstance(db, StandardDatabase): + raise TypeError( + "**db** must be an instance of arango.database.StandardDatabase" + ) + + self.__db = db + return + + self.__host = os.getenv("DATABASE_HOST") + self.__username = os.getenv("DATABASE_USERNAME") + self.__password = os.getenv("DATABASE_PASSWORD") + self.__db_name = os.getenv("DATABASE_NAME") + + # TODO: Raise a custom exception if any of the environment + # variables are missing. For now, we'll just set db to None. + if not all([self.__host, self.__username, self.__password, self.__db_name]): + self.__db = None + return + + self.__db = ArangoClient(hosts=self.__host, request_timeout=None).db( + self.__db_name, self.__username, self.__password, verify=True + ) + + def set_graph_name(self, graph_name: str | None = None): + if self.__db is None: + raise ValueError("Cannot set graph name without setting the database first") + + self.__graph_name = os.getenv("DATABASE_GRAPH_NAME") + if graph_name is not None: + if not isinstance(graph_name, str): + raise TypeError("**graph_name** must be a string") + + self.__graph_name = graph_name + + if self.__graph_name is None: + self.__graph_exists = False + print("DATABASE_GRAPH_NAME environment variable not set") + + elif not self.db.has_graph(self.__graph_name): + self.__graph_exists = False + print(f"Graph '{self.__graph_name}' does not exist in the database") + + else: + self.__graph_exists = True + print(f"Found graph '{self.__graph_name}' in the database") + + def pull(self, load_node_and_adj_dict=True, load_coo=True): + if not self.graph_exists: + raise ValueError("Graph does not exist in the database") + + adb_graph = self.db.graph(self.graph_name) + + v_cols = adb_graph.vertex_collections() + edge_definitions = adb_graph.edge_definitions() + e_cols = {c["edge_collection"] for c in edge_definitions} + + metagraph = { + "vertexCollections": {col: {} for col in v_cols}, + "edgeCollections": {col: {} for col in e_cols}, + } + + from phenolrs.graph_loader import GraphLoader + + kwargs = {} + if self.graph_loader_parallelism is not None: + kwargs["parallelism"] = self.graph_loader_parallelism + if self.graph_loader_batch_size is not None: + kwargs["batch_size"] = self.graph_loader_batch_size + + result = GraphLoader.load( + self.db.name, + metagraph, + [self.__host], + username=self.__username, + password=self.__password, + load_node_dict=load_node_and_adj_dict, + load_adj_dict=load_node_and_adj_dict, + load_adj_dict_as_undirected=True, + load_coo=load_coo, + **kwargs, + ) + + if load_node_and_adj_dict: + # hacky, i don't like this + # need to revisit... + # consider using nx.convert.from_dict_of_dicts instead + self._node = result[0] + self._adj = result[1] + + if load_coo: + self.src_indices = result[2] + self.dst_indices = result[3] + self.vertex_ids_to_index = result[4] + + def push(self): + raise NotImplementedError("What would this look like?") diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py index 3b83aea5..e2f86584 100644 --- a/nx_arangodb/classes/multidigraph.py +++ b/nx_arangodb/classes/multidigraph.py @@ -13,3 +13,21 @@ class MultiDiGraph(nx.MultiDiGraph, MultiGraph, DiGraph): @classmethod def to_networkx_class(cls) -> type[nx.MultiDiGraph]: return nx.MultiDiGraph + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.__db = None + self.__graph_name = None + self.__graph_exists = False + + self.coo_load_parallelism = None + self.coo_load_batch_size = None + + self.set_db() + if self.__db is not None: + self.set_graph_name() diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index 6afe0c63..0a252284 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -12,3 +12,21 @@ class MultiGraph(nx.MultiGraph, Graph): @classmethod def to_networkx_class(cls) -> type[nx.MultiGraph]: return nx.MultiGraph + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.__db = None + self.__graph_name = None + self.__graph_exists = False + + self.coo_load_parallelism = None + self.coo_load_batch_size = None + + self.set_db() + if self.__db is not None: + self.set_graph_name() diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index fae54ce4..12f46c11 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -3,21 +3,14 @@ from __future__ import annotations import itertools -import operator as op -from collections import Counter -from collections.abc import Mapping from typing import TYPE_CHECKING -# import cupy as cp import networkx as nx -import numpy as np import nx_arangodb as nxadb -from .utils import index_dtype - if TYPE_CHECKING: # pragma: no cover - from nx_arangodb.typing import AttrKey, Dtype, EdgeValue, NodeValue, any_ndarray + from nx_arangodb.typing import AttrKey, Dtype, EdgeValue, NodeValue __all__ = [ "from_networkx", @@ -167,19 +160,34 @@ def to_networkx(G: nxadb.Graph, *, sort_edges: bool = False) -> nx.Graph: return G.to_networkx_class()(incoming_graph_data=G) -def _to_graph( +def from_networkx_arangodb(G: nxadb.Graph) -> nxadb.Graph: + if not G.graph_exists: + print("ANTHONY: Graph does not exist, nothing to pull") + return G + + if G.use_node_and_adj_dict_cache and len(G.nodes) > 0 and len(G.adj) > 0: + print("ANTHONY: Using cached node and adj dict") + return G + + start_time = time.time() + G.pull(load_coo=False) + end_time = time.time() + + print("ANTHONY: Node & Adj Load took:", end_time - start_time) + + return G + + +def _to_nxadb_graph( G, edge_attr: AttrKey | None = None, edge_default: EdgeValue | None = 1, edge_dtype: Dtype | None = None, ) -> nxadb.Graph | nxadb.DiGraph: - """Ensure that input type is a nx_arangodb graph, and convert if necessary. - - Directed and undirected graphs are both allowed. - This is an internal utility function and may change or be removed. - """ + """Ensure that input type is a nx_arangodb graph, and convert if necessary.""" if isinstance(G, nxadb.Graph): - return G + return from_networkx_arangodb(G) + if isinstance(G, nx.Graph): return from_networkx( G, {edge_attr: edge_default} if edge_attr is not None else None, edge_dtype @@ -188,50 +196,102 @@ def _to_graph( raise TypeError -def _to_directed_graph( - G, - edge_attr: AttrKey | None = None, - edge_default: EdgeValue | None = 1, - edge_dtype: Dtype | None = None, -) -> nxadb.DiGraph: - """Ensure that input type is a nx_arangodb DiGraph, and convert if necessary. +try: + import os + import time + + import cupy as cp + import numpy as np + import nx_cugraph as nxcg + + def _to_nxcg_graph( + G, + edge_attr: AttrKey | None = None, + edge_default: EdgeValue | None = 1, + edge_dtype: Dtype | None = None, + as_directed: bool = False, + ) -> nxcg.Graph | nxcg.DiGraph: + """Ensure that input type is a nx_cugraph graph, and convert if necessary.""" + if isinstance(G, nxcg.Graph): + return G + if isinstance(G, nxadb.Graph): + # Assumption: G.adb_graph_name points to an existing graph in ArangoDB + # Therefore, the user wants us to pull the graph from ArangoDB, + # and convert it to an nx_cugraph graph. + # We currently accomplish this by using the NetworkX adapter for ArangoDB, + # which converts the ArangoDB graph to a NetworkX graph, and then we convert + # the NetworkX graph to an nx_cugraph graph. + # TODO: Implement a direct conversion from ArangoDB to nx_cugraph + if G.graph_exists: + print("ANTHONY: Graph exists, running _nxadb_to_nxcg()") + return _nxadb_to_nxcg(G, as_directed=as_directed) + + # If G is a networkx graph, or is a nxadb graph that doesn't point to an "existing" + # ArangoDB graph, then we just treat it as a normal networkx graph & + # convert it to nx_cugraph. + # TODO: Need to revisit the "existing" ArangoDB graph condition... + if isinstance(G, nx.Graph): + return nxcg.convert.from_networkx( + G, + {edge_attr: edge_default} if edge_attr is not None else None, + edge_dtype, + as_directed=as_directed, + ) + + # TODO: handle cugraph.Graph + raise TypeError + + def _nxadb_to_nxcg( + G: nxadb.Graph, as_directed: bool = False + ) -> nxcg.Graph | nxcg.DiGraph: + if G.is_multigraph(): + raise NotImplementedError("Multigraphs not yet supported") + + if ( + G.use_coo_cache + and G.src_indices is not None + and G.dst_indices is not None + and G.vertex_ids_to_index is not None + ): + print("ANTHONY: Using cached COO") - Undirected graphs will be converted to directed. - This is an internal utility function and may change or be removed. - """ - if isinstance(G, nxadb.DiGraph): - return G - if isinstance(G, nxadb.Graph): - return G.to_directed() - if isinstance(G, nx.Graph): - return from_networkx( - G, - {edge_attr: edge_default} if edge_attr is not None else None, - edge_dtype, - as_directed=True, - ) - # TODO: handle cugraph.Graph - raise TypeError + else: + start_time = time.time() + G.pull(load_node_and_adj_dict=False) + end_time = time.time() + print("ANTHONY: COO Load took:", end_time - start_time) -def _to_undirected_graph( - G, - edge_attr: AttrKey | None = None, - edge_default: EdgeValue | None = 1, - edge_dtype: Dtype | None = None, -) -> nxadb.Graph: - """Ensure that input type is a nx_arangodb Graph, and convert if necessary. + N = len(G.vertex_ids_to_index) - Only undirected graphs are allowed. Directed graphs will raise ValueError. - This is an internal utility function and may change or be removed. - """ - if isinstance(G, nxadb.Graph): - if G.is_directed(): - raise ValueError("Only undirected graphs supported; got a directed graph") - return G - if isinstance(G, nx.Graph): - return from_networkx( - G, {edge_attr: edge_default} if edge_attr is not None else None, edge_dtype + if G.is_directed() or as_directed: + klass = nxcg.DiGraph + else: + klass = nxcg.Graph + + start_time = time.time() + + rv = klass.from_coo( + N, + cp.array(G.src_indices), + cp.array(G.dst_indices), + key_to_id=G.vertex_ids_to_index, ) - # TODO: handle cugraph.Graph - raise TypeError + end_time = time.time() + + print("ANTHONY: from_coo took:", end_time - start_time) + + return rv + +except ModuleNotFoundError as e: + print(f"ANTHONY: {e}") + + def _to_nxcg_graph( + G, + edge_attr: AttrKey | None = None, + edge_default: EdgeValue | None = 1, + edge_dtype: Dtype | None = None, + as_directed: bool = False, + ) -> nxadb.Graph: + m = "nx-cugraph is not installed; cannot convert to nx-cugraph graph" + raise NotImplementedError(m) diff --git a/nx_arangodb/utils/misc.py b/nx_arangodb/utils/misc.py index 105228fd..55e339d3 100644 --- a/nx_arangodb/utils/misc.py +++ b/nx_arangodb/utils/misc.py @@ -31,6 +31,7 @@ def pairwise(it): "index_dtype", "_seed_to_int", "_get_int_dtype", + "_dtype_param", ] # This may switch to np.uint32 at some point diff --git a/tests/test.py b/tests/test.py index cc7692f3..f21b176b 100644 --- a/tests/test.py +++ b/tests/test.py @@ -14,9 +14,34 @@ def test_bc(): G_2 = nxadb.Graph(G_1) - bc_1 = nx.betweenness_centrality(G_1) - bc_2 = nx.betweenness_centrality(G_2) - bc_3 = nx.betweenness_centrality(G_1, backend="arangodb") - bc_4 = nx.betweenness_centrality(G_2, backend="arangodb") + r_1 = nx.betweenness_centrality(G_1) + r_2 = nx.betweenness_centrality(G_2) + r_3 = nx.betweenness_centrality(G_1, backend="arangodb") + r_4 = nx.betweenness_centrality(G_2, backend="arangodb") - assert bc_1 == bc_2 == bc_3 == bc_4 \ No newline at end of file + assert r_1 and r_2 and r_3 and r_4 + +def test_pagerank(): + G_1 = nx.karate_club_graph() + + G_2 = nxadb.Graph(G_1) + + r_1 = nx.pagerank(G_1) + r_2 = nx.pagerank(G_2) + r_3 = nx.pagerank(G_1, backend="arangodb") + r_4 = nx.pagerank(G_2, backend="arangodb") + + assert r_1 and r_2 and r_3 and r_4 + + +def test_louvain(): + G_1 = nx.karate_club_graph() + + G_2 = nxadb.Graph(G_1) + + r_1 = nx.community.louvain_communities(G_1) + r_2 = nx.community.louvain_communities(G_2) + r_3 = nx.community.louvain_communities(G_1, backend="arangodb") + r_4 = nx.community.louvain_communities(G_2, backend="arangodb") + + assert r_1 and r_2 and r_3 and r_4