From ede0b82065dfa5fadb27c8c9f8d41943b030533e Mon Sep 17 00:00:00 2001 From: Anthony Mahanna <43019056+aMahanna@users.noreply.github.com> Date: Mon, 2 Sep 2024 11:33:37 -0400 Subject: [PATCH] misc cleanup (#54) * misc cleanup * fix: typo * fix: `test_shortest_path` --- README.md | 23 +- _nx_arangodb/__init__.py | 5 +- .../algorithms/shortest_paths/generic.py | 2 +- nx_arangodb/classes/dict/node.py | 17 +- nx_arangodb/classes/digraph.py | 9 + nx_arangodb/classes/function.py | 18 ++ nx_arangodb/classes/graph.py | 206 +++++++++--------- nx_arangodb/classes/multidigraph.py | 1 - nx_arangodb/convert.py | 9 +- nx_arangodb/interface.py | 3 +- pyproject.toml | 7 +- run_nx_tests.sh | 2 +- tests/test.py | 28 +-- 13 files changed, 160 insertions(+), 170 deletions(-) diff --git a/README.md b/README.md index 43fe9f6b..211b1f09 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ # nx-arangodb - -
@@ -55,7 +50,7 @@ Benefits of having ArangoDB as a backend to NetworkX include: ## Does this replace NetworkX? -No. This is a plugin to NetworkX, which means that you can use NetworkX as you normally would, but with the added benefit of persisting your graphs to a database. +Not really. This is a plugin to NetworkX, which means that you can use NetworkX as you normally would, but with the added benefit of persisting your graphs to a database. ```python import os @@ -111,7 +106,7 @@ pip install nx-cugraph-cu12 --extra-index-url https://pypi.nvidia.com pip install nx-arangodb ``` -## What are the easiests ways to set up ArangoDB? +## How can I set up ArangoDB? **1) Local Instance via Docker** @@ -149,7 +144,7 @@ os.environ["DATABASE_NAME"] = credentials["dbName"] # ... ``` -## How does Algorithm Dispatching work? +## How does algorithm dispatching work? `nx-arangodb` will automatically dispatch algorithm calls to either CPU or GPU based on if `nx-cugraph` is installed. We rely on a rust-based library called [phenolrs](https://github.com/arangoml/phenolrs) to retrieve ArangoDB Graphs as fast as possible. diff --git a/_nx_arangodb/__init__.py b/_nx_arangodb/__init__.py index 0498f571..616e961b 100644 --- a/_nx_arangodb/__init__.py +++ b/_nx_arangodb/__init__.py @@ -26,8 +26,8 @@ "project": "nx-arangodb", "package": "nx_arangodb", "url": "https://github.com/arangodb/nx-arangodb", - "short_summary": "Remote storage backend.", - # "description": "TODO", + "short_summary": "ArangoDB storage backend to NetworkX.", + "description": "Persist, maintain, and reload NetworkX graphs with ArangoDB.", "functions": { # BEGIN: functions "shortest_path", @@ -81,7 +81,6 @@ def get_info(): "db_name": None, "read_parallelism": None, "read_batch_size": None, - "write_batch_size": None, "use_gpu": True, } diff --git a/nx_arangodb/algorithms/shortest_paths/generic.py b/nx_arangodb/algorithms/shortest_paths/generic.py index f5a9025b..7328b257 100644 --- a/nx_arangodb/algorithms/shortest_paths/generic.py +++ b/nx_arangodb/algorithms/shortest_paths/generic.py @@ -54,7 +54,7 @@ def shortest_path( "weight": weight, } - result = list(G.aql(query, bind_vars=bind_vars)) + result = list(G.query(query, bind_vars=bind_vars)) if not result: raise nx.NodeNotFound(f"Either source {source} or target {target} is not in G") diff --git a/nx_arangodb/classes/dict/node.py b/nx_arangodb/classes/dict/node.py index f41b1666..e55c5171 100644 --- a/nx_arangodb/classes/dict/node.py +++ b/nx_arangodb/classes/dict/node.py @@ -19,6 +19,7 @@ doc_delete, doc_insert, doc_update, + edges_delete, get_arangodb_graph, get_node_id, get_node_type_and_id, @@ -303,21 +304,7 @@ def __delitem__(self, key: str) -> None: if not self.graph.has_vertex(node_id): raise KeyError(key) - # TODO: wrap in edges_delete() method - remove_statements = "\n".join( - f"REMOVE e IN `{edge_def['edge_collection']}` OPTIONS {{ignoreErrors: true}}" # noqa - for edge_def in self.graph.edge_definitions() - ) - - query = f""" - FOR v, e IN 1..1 ANY @src_node_id GRAPH @graph_name - {remove_statements} - """ - - bind_vars = {"src_node_id": node_id, "graph_name": self.graph.name} - - aql(self.db, query, bind_vars) - ##### + edges_delete(self.db, self.graph, node_id) doc_delete(self.db, node_id) diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index e2bea65c..ccf7d65f 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -66,6 +66,13 @@ def __init__( self.remove_node = self.remove_node_override self.reverse = self.reverse_override + assert isinstance(self._succ, AdjListOuterDict) + assert isinstance(self._pred, AdjListOuterDict) + self._succ.mirror = self._pred + self._pred.mirror = self._succ + self._succ.traversal_direction = TraversalDirection.OUTBOUND + self._pred.traversal_direction = TraversalDirection.INBOUND + if ( not self.is_multigraph() and incoming_graph_data is not None @@ -78,6 +85,8 @@ def __init__( ####################### # TODO? + # If we want to continue with "Experimental Views" we need to implement the + # InEdgeView and OutEdgeView classes. # @cached_property # def in_edges(self): # pass diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 19f09324..c9b73822 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -578,6 +578,24 @@ def doc_delete(db: StandardDatabase, id: str, **kwargs: Any) -> None: db.delete_document(id, silent=True, **kwargs) +def edges_delete( + db: StandardDatabase, graph: Graph, src_node_id: str, **kwargs: Any +) -> None: + remove_statements = "\n".join( + f"REMOVE e IN `{edge_def['edge_collection']}` OPTIONS {{ignoreErrors: true}}" # noqa + for edge_def in graph.edge_definitions() + ) + + query = f""" + FOR v, e IN 1..1 ANY @src_node_id GRAPH @graph_name + {remove_statements} + """ + + bind_vars = {"src_node_id": src_node_id, "graph_name": graph.name} + + aql(db, query, bind_vars) + + def doc_insert( db: StandardDatabase, collection: str, diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 53e7abd6..2bb23831 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -27,8 +27,6 @@ node_attr_dict_factory, node_dict_factory, ) -from .dict.adj import AdjListOuterDict -from .enum import TraversalDirection from .function import get_node_id from .reportviews import CustomEdgeView, CustomNodeView @@ -76,27 +74,21 @@ def __init__( *args: Any, **kwargs: Any, ): - self._db = None + self.__db = None self.__name = None self.__use_experimental_views = use_experimental_views + self.__graph_exists_in_db = False - self._graph_exists_in_db = False - self._loaded_incoming_graph_data = False - - self._set_db(db) - if self._db is not None: - self._set_graph_name(name) + self.__set_db(db) + if self.__db is not None: + self.__set_graph_name(name) - self.read_parallelism = read_parallelism - self.read_batch_size = read_batch_size - self.write_batch_size = write_batch_size - - self._set_edge_collections_attributes_to_fetch(edge_collections_attributes) + self.__set_edge_collections_attributes(edge_collections_attributes) # NOTE: Need to revisit these... # self.maintain_node_dict_cache = False # self.maintain_adj_dict_cache = False - self.use_nx_cache = True + # self.use_nx_cache = True self.use_nxcg_cache = True self.nxcg_graph = None @@ -111,7 +103,9 @@ def __init__( # m = "Must set **graph_name** if passing **incoming_graph_data**" # raise ValueError(m) - if self._graph_exists_in_db: + self._loaded_incoming_graph_data = False + + if self.__graph_exists_in_db: if incoming_graph_data is not None: m = "Cannot pass both **incoming_graph_data** and **name** yet if the already graph exists" # noqa: E501 raise NotImplementedError(m) @@ -152,7 +146,7 @@ def edge_type_func(u: str, v: str) -> str: self.default_node_type = default_node_type self._set_factory_methods() - self._set_arangodb_backend_config() + self.__set_arangodb_backend_config(read_parallelism, read_batch_size) elif self.__name: @@ -181,7 +175,7 @@ def edge_type_func(u: str, v: str) -> str: self.__name, incoming_graph_data, edge_definitions=edge_definitions, - batch_size=self.write_batch_size, + batch_size=write_batch_size, use_async=write_async, ) @@ -194,23 +188,15 @@ def edge_type_func(u: str, v: str) -> str: ) self._set_factory_methods() - self._set_arangodb_backend_config() + self.__set_arangodb_backend_config(read_parallelism, read_batch_size) logger.info(f"Graph '{name}' created.") - self._graph_exists_in_db = True + self.__graph_exists_in_db = True if self.__name is not None: kwargs["name"] = self.__name super().__init__(*args, **kwargs) - if self.is_directed() and self.graph_exists_in_db: - assert isinstance(self._succ, AdjListOuterDict) - assert isinstance(self._pred, AdjListOuterDict) - self._succ.mirror = self._pred - self._pred.mirror = self._succ - self._succ.traversal_direction = TraversalDirection.OUTBOUND - self._pred.traversal_direction = TraversalDirection.INBOUND - if self.graph_exists_in_db: self.copy = self.copy_override self.subgraph = self.subgraph_override @@ -220,6 +206,11 @@ def edge_type_func(u: str, v: str) -> str: self.number_of_edges = self.number_of_edges_override self.nbunch_iter = self.nbunch_iter_override + # If incoming_graph_data wasn't loaded by the NetworkX Adapter, + # then we can rely on the CRUD operations of the modified dictionaries + # to load the data into the graph. However, if the graph is directed + # or multigraph, then we leave that responsibility to the child classes + # due to the possibility of additional CRUD-based method overrides. if ( not self.is_directed() and not self.is_multigraph() @@ -232,21 +223,6 @@ def edge_type_func(u: str, v: str) -> str: # Init helper methods # ####################### - def _set_arangodb_backend_config(self) -> None: - if not all([self._host, self._username, self._password, self._db_name]): - m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501 - raise OSError(m) - - config = nx.config.backends.arangodb - config.host = self._host - config.username = self._username - config.password = self._password - config.db_name = self._db_name - config.read_parallelism = self.read_parallelism - config.read_batch_size = self.read_batch_size - config.write_batch_size = self.write_batch_size - config.use_gpu = True # Only used by default if nx-cugraph is available - def _set_factory_methods(self) -> None: """Set the factory methods for the graph, _node, and _adj dictionaries. @@ -281,58 +257,33 @@ def _set_factory_methods(self) -> None: *adj_args, self.symmetrize_edges ) - def _set_edge_collections_attributes_to_fetch( - self, attributes: set[str] | None + def __set_arangodb_backend_config( + self, read_parallelism: int, read_batch_size: int ) -> None: - if attributes is None: - self._edge_collections_attributes = set() - return - if len(attributes) > 0: - self._edge_collections_attributes = attributes - if "_id" not in attributes: - self._edge_collections_attributes.add("_id") - - ########### - # Getters # - ########### - - @property - def db(self) -> StandardDatabase: - if self._db is None: - raise DatabaseNotSet("Database not set") - - return self._db - - @property - def name(self) -> str: - if self.__name is None: - raise GraphNameNotSet("Graph name not set") - - return self.__name - - @name.setter - def name(self, s): - if self.__name is not None: - raise ValueError("Existing graph cannot be renamed") + if not all([self._host, self._username, self._password, self._db_name]): + m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501 + raise OSError(m) - self.__name = s - m = "Note that setting the graph name does not create the graph in the database" # noqa: E501 - logger.warning(m) - nx._clear_cache(self) + config = nx.config.backends.arangodb + config.host = self._host + config.username = self._username + config.password = self._password + config.db_name = self._db_name + config.read_parallelism = read_parallelism + config.read_batch_size = read_batch_size + config.use_gpu = True # Only used by default if nx-cugraph is available - @property - def graph_exists_in_db(self) -> bool: - return self._graph_exists_in_db + def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None: + if not attributes: + self._edge_collections_attributes = set() + return - @property - def get_edge_attributes(self) -> set[str]: - return self._edge_collections_attributes + self._edge_collections_attributes = attributes - ########### - # Setters # - ########### + if "_id" not in attributes: + self._edge_collections_attributes.add("_id") - def _set_db(self, db: StandardDatabase | None = None) -> None: + def __set_db(self, db: Any = None) -> None: self._host = os.getenv("DATABASE_HOST") self._username = os.getenv("DATABASE_USERNAME") self._password = os.getenv("DATABASE_PASSWORD") @@ -344,27 +295,26 @@ def _set_db(self, db: StandardDatabase | None = None) -> None: raise TypeError(m) db.version() - self._db = db + self.__db = db return - # TODO: Raise a custom exception if any of the environment - # variables are missing. For now, we'll just set db to None. if not all([self._host, self._username, self._password, self._db_name]): - self._db = None - logger.warning("Database environment variables not set") + m = "Database environment variables not set. Can't connect to the database" + logger.warning(m) + self.__db = None return - self._db = ArangoClient(hosts=self._host, request_timeout=None).db( + self.__db = ArangoClient(hosts=self._host, request_timeout=None).db( self._db_name, self._username, self._password, verify=True ) - def _set_graph_name(self, name: str | None = None) -> None: - if self._db is None: + def __set_graph_name(self, name: Any = None) -> None: + if self.__db is None: m = "Cannot set graph name without setting the database first" raise DatabaseNotSet(m) if name is None: - self._graph_exists_in_db = False + self.__graph_exists_in_db = False logger.warning(f"**name** not set for {self.__class__.__name__}") return @@ -372,9 +322,51 @@ def _set_graph_name(self, name: str | None = None) -> None: raise TypeError("**name** must be a string") self.__name = name - self._graph_exists_in_db = self.db.has_graph(name) + self.__graph_exists_in_db = self.db.has_graph(name) + + logger.info(f"Graph '{name}' exists: {self.__graph_exists_in_db}") + + ########### + # Getters # + ########### + + @property + def db(self) -> StandardDatabase: + if self.__db is None: + raise DatabaseNotSet("Database not set") + + return self.__db + + @property + def name(self) -> str: + if self.__name is None: + raise GraphNameNotSet("Graph name not set") + + return self.__name - logger.info(f"Graph '{name}' exists: {self._graph_exists_in_db}") + @name.setter + def name(self, s): + if self.graph_exists_in_db: + raise ValueError("Existing graph cannot be renamed") + + m = "Note that setting the graph name does not create the graph in the database" # noqa: E501 + logger.warning(m) + + self.__name = s + self.graph["name"] = s + nx._clear_cache(self) + + @property + def graph_exists_in_db(self) -> bool: + return self.__graph_exists_in_db + + @property + def edge_attributes(self) -> set[str]: + return self._edge_collections_attributes + + ########### + # Setters # + ########### #################### # ArangoDB Methods # @@ -383,7 +375,9 @@ def _set_graph_name(self, name: str | None = None) -> None: def clear_nxcg_cache(self): self.nxcg_graph = None - def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor: + def query( + self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any + ) -> Cursor: return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs) # def pull(self) -> None: @@ -399,7 +393,7 @@ def chat( m = "LLM dependencies not installed. Install with **pip install nx-arangodb[llm]**" # noqa: E501 raise ModuleNotFoundError(m) - if not self._graph_exists_in_db: + if not self.__graph_exists_in_db: m = "Cannot chat without a graph in the database" raise GraphNameNotSet(m) @@ -440,7 +434,7 @@ def adj(self): def edges(self): if self.__use_experimental_views and self.graph_exists_in_db: if self.is_directed(): - logger.warning("CustomEdgeView for Directed Graphs not yet implemented") + logger.warning("CustomEdgeView for DiGraphs not yet implemented") return super().edges if self.is_multigraph(): @@ -463,7 +457,11 @@ def copy_override(self, *args, **kwargs): return G def subgraph_override(self, nbunch): - raise NotImplementedError("Subgraphing is not yet implemented") + if self.graph_exists_in_db: + m = "Subgraphing an ArangoDB Graph is not yet implemented" + raise NotImplementedError(m) + + return super().subgraph(nbunch) def clear_override(self): logger.info("Note that clearing only erases the local cache") diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py index d9a57f9e..fe25eb93 100644 --- a/nx_arangodb/classes/multidigraph.py +++ b/nx_arangodb/classes/multidigraph.py @@ -7,7 +7,6 @@ import nx_arangodb as nxadb from nx_arangodb.classes.digraph import DiGraph from nx_arangodb.classes.multigraph import MultiGraph -from nx_arangodb.logger import logger networkx_api = nxadb.utils.decorators.networkx_class(nx.MultiDiGraph) # type: ignore diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 16a724d0..17458b90 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -13,7 +13,6 @@ try: import cupy as cp - import numpy as np import nx_cugraph as nxcg GPU_AVAILABLE = True @@ -127,9 +126,9 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph: load_node_dict=True, load_adj_dict=True, load_coo=False, - edge_collections_attributes=G.get_edge_attributes, + edge_collections_attributes=G.edge_attributes, load_all_vertex_attributes=False, - load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes), + load_all_edge_attributes=do_load_all_edge_attributes(G.edge_attributes), is_directed=G.is_directed(), is_multigraph=G.is_multigraph(), symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False, @@ -185,9 +184,9 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: load_node_dict=False, load_adj_dict=False, load_coo=True, - edge_collections_attributes=G.get_edge_attributes, + edge_collections_attributes=G.edge_attributes, load_all_vertex_attributes=False, # not used - load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes), + load_all_edge_attributes=do_load_all_edge_attributes(G.edge_attributes), is_directed=G.is_directed(), is_multigraph=G.is_multigraph(), symmetrize_edges_if_directed=( diff --git a/nx_arangodb/interface.py b/nx_arangodb/interface.py index 725c110d..47048752 100644 --- a/nx_arangodb/interface.py +++ b/nx_arangodb/interface.py @@ -62,7 +62,6 @@ def _auto_func(func_name: str, /, *args: Any, **kwargs: Any) -> Any: """ dfunc = _registered_algorithms[func_name] - # TODO: Use `nx.config.backends.arangodb.backend_priority` instead backend_priority = [] if nxadb.convert.GPU_AVAILABLE and nx.config.backends.arangodb.use_gpu: backend_priority.append("cugraph") @@ -143,6 +142,8 @@ def _run_with_backend( # TODO: Convert to nxadb.Graph? # What would this look like? Create a new graph in ArangoDB? # Or just establish a remote connection? + # For now, if dfunc._returns_graph is True, it will return a + # regular nx.Graph object. # if dfunc._returns_graph: # raise NotImplementedError("Returning Graphs not implemented yet") diff --git a/pyproject.toml b/pyproject.toml index 6f930e30..7d181d32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,6 @@ authors = [ license = { text = "Apache 2.0" } requires-python = ">=3.10" classifiers = [ - "Development Status :: 3 - Alpha", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3", @@ -30,7 +29,6 @@ classifiers = [ ] dependencies = [ "networkx>=3.0,<=3.3", - "numpy>=1.23,<2.0a0", "phenolrs", "python-arango", "adbnx-adapter" @@ -55,9 +53,6 @@ dev = [ "sphinx", "sphinx_rtd_theme", ] -gpu = [ - "nx-cugraph-cu12 @ https://pypi.nvidia.com" -] llm = [ "langchain~=0.2.14", "langchain-openai~=0.1.22", @@ -65,7 +60,7 @@ llm = [ ] [project.urls] -Homepage = "https://github.com/aMahanna/nx-arangodb" +Homepage = "https://github.com/arangodb/nx-arangodb" # "plugin" used in nx version < 3.2 [project.entry-points."networkx.plugins"] diff --git a/run_nx_tests.sh b/run_nx_tests.sh index 977e991b..6d0c499f 100755 --- a/run_nx_tests.sh +++ b/run_nx_tests.sh @@ -10,7 +10,7 @@ NETWORKX_FALLBACK_TO_NX=True \ --cov-report= \ "$@" coverage report \ - --include="*/nx_arangodb/algorithms/*" \ + --include="*/nx_arangodb/classes/*" \ --omit=__init__.py \ --show-missing \ --rcfile=$(dirname $0)/pyproject.toml diff --git a/tests/test.py b/tests/test.py index 79277180..58ea73f8 100644 --- a/tests/test.py +++ b/tests/test.py @@ -65,23 +65,6 @@ def assert_pagerank( assert_same_dict_values(d1, d2, digit) -def assert_louvain(l1: list[set[Any]], l2: list[set[Any]]) -> None: - # TODO: Implement some kind of comparison - # Reason: Louvain returns different results on different runs - assert l1 - assert l2 - pass - - -def assert_k_components( - d1: dict[int, list[set[Any]]], d2: dict[int, list[set[Any]]] -) -> None: - assert d1 - assert d2 - assert d1.keys() == d2.keys(), "Dictionaries have different keys" - assert d1 == d2 - - def test_db(load_karate_graph: Any) -> None: assert db.version() @@ -312,7 +295,7 @@ def assert_symmetry_differences( assert_func(r_13_orig, r_9_orig) -def test_shortest_path_remote_algorithm(load_karate_graph: Any) -> None: +def test_shortest_path(load_karate_graph: Any) -> None: G_1 = nxadb.Graph(name="KarateGraph") G_2 = nxadb.DiGraph(name="KarateGraph") @@ -321,8 +304,15 @@ def test_shortest_path_remote_algorithm(load_karate_graph: Any) -> None: r_3 = nx.shortest_path(G_2, source="person/0", target="person/33") r_4 = nx.shortest_path(G_2, source="person/0", target="person/33", weight="weight") + r_5 = nx.shortest_path.orig_func( + G_1, source="person/0", target="person/33", weight="weight" + ) + r_6 = nx.shortest_path.orig_func( + G_2, source="person/0", target="person/33", weight="weight" + ) + assert r_1 == r_3 - assert r_2 == r_4 + assert r_2 == r_4 == r_5 == r_6 assert r_1 != r_2 assert r_3 != r_4