From 5ed8afa4ca9b76ae19c488fed0a440e39320708c Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Sat, 4 May 2024 00:27:38 -0400 Subject: [PATCH] checkpoint (CI is failing) --- nx_arangodb/classes/dict.py | 87 ++++++++++++++++++------ nx_arangodb/classes/exceptions.py | 17 +++++ nx_arangodb/classes/function.py | 107 ++++++++++++++++++++++++------ nx_arangodb/classes/graph.py | 6 +- 4 files changed, 173 insertions(+), 44 deletions(-) create mode 100644 nx_arangodb/classes/exceptions.py diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index 7e18a89e..62e6c4a7 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -5,6 +5,7 @@ from typing import Any, Callable from arango.database import StandardDatabase +from arango.exceptions import DocumentInsertError from arango.graph import Graph from .function import ( @@ -14,7 +15,9 @@ aql_doc_get_keys, aql_doc_get_length, aql_doc_has_key, + aql_edge_exists, aql_edge_get, + aql_edge_id, aql_single, create_collection, doc_delete, @@ -84,7 +87,7 @@ class GraphDict(UserDict): :type graph_name: str """ - COLLECTION_NAME = "NXADB_GRAPH_ATTRIBUTES" + COLLECTION_NAME = "nxadb_graphs" def __init__(self, db: StandardDatabase, graph_name: str, *args, **kwargs): super().__init__(*args, **kwargs) @@ -134,8 +137,8 @@ def __setitem__(self, key: str, value: Any): @key_is_not_reserved def __delitem__(self, key): """del G.graph['foo']""" - del self.data[key] doc_update(self.db, self.graph_id, {key: None}) + self.data.pop(key, None) @keys_are_strings @keys_are_not_reserved @@ -196,8 +199,8 @@ def __getitem__(self, key: str) -> NodeAttrDict: """G._node['node/1']""" node_id = get_node_id(key, self.default_node_type) - if node_id in self.data: - return self.data[node_id] + if value := self.data.get(node_id): + return value if value := self.graph.vertex(node_id): node_attr_dict: NodeAttrDict = self.node_attr_dict_factory() @@ -234,8 +237,8 @@ def __delitem__(self, key: Any) -> None: """del g._node['node/1']""" node_id = get_node_id(key, self.default_node_type) - del self.data[node_id] doc_delete(self.db, node_id) + self.data.pop(node_id, None) def __len__(self) -> int: """len(g._node)""" @@ -378,8 +381,8 @@ def __setitem__(self, key: str, value: Any): @key_is_not_reserved def __delitem__(self, key: str): """del G._node['node/1']['foo']""" - del self.data[key] doc_update(self.db, self.node_id, {key: None}) + self.data.pop(key, None) def __iter__(self) -> Iterator[str]: """for key in G._node['node/1']""" @@ -485,8 +488,8 @@ def __getitem__(self, key) -> AdjListInnerDict: """G.adj["node/1"]""" node_type, node_id = get_node_type_and_id(key, self.default_node_type) - if node_id in self.data: - return self.data[node_id] + if value := self.data.get(node_id): + return value if self.graph.has_vertex(node_id): adjlist_inner_dict: AdjListInnerDict = self.adjlist_inner_dict_factory() @@ -520,12 +523,12 @@ def __setitem__(self, src_key: str, adjlist_inner_dict: AdjListInnerDict): dst_key, self.default_node_type ) - edge_type = edge_dict.get("_edge_type") # pop? + edge_type = edge_dict.get("_edge_type") if edge_type is None: edge_type = self.edge_type_func(src_node_type, dst_node_type) results[dst_key] = self.graph.link( - edge_type, src_node_id, dst_node_id, edge_dict, silent=True + edge_type, src_node_id, dst_node_id, edge_dict ) adjlist_inner_dict.src_node_id = src_node_id @@ -539,7 +542,24 @@ def __delitem__(self, key: Any) -> None: """ del G._adj['node/1'] """ - raise NotImplementedError("AdjListOuterDict.__delitem__()") + node_id = get_node_id(key, self.default_node_type) + + if not self.graph.has_vertex(node_id): + return + + remove_statements = "\n".join( + f"REMOVE e IN `{edge_def['edge_collection']}` OPTIONS {{ignoreErrors: true}}" + for edge_def in self.graph.edge_definitions() + ) + + query = f""" + FOR v, e IN 1..1 OUTBOUND @src_node_id GRAPH @graph_name + {remove_statements} + """ + + bind_vars = {"src_node_id": node_id, "graph_name": self.graph.name} + + aql(self.db, query, bind_vars) def __len__(self) -> int: """len(g._adj)""" @@ -584,7 +604,7 @@ def update(self, edges: dict[str, dict[str, dict[str, Any]]]): for dst_key, edge_dict in dst_dict.items(): dst_node_type, dst_node_id = get_node_type_and_id(dst_key) - edge_type = edge_dict.get("_edge_type") # pop? + edge_type = edge_dict.get("_edge_type") if edge_type is None: edge_type = self.edge_type_func(src_node_type, dst_node_type) @@ -693,17 +713,14 @@ def __contains__(self, key) -> bool: if dst_node_id in self.data: return True - return aql_edge_get( + return aql_edge_exists( self.db, self.src_node_id, dst_node_id, self.graph.name, direction="OUTBOUND", - return_bool=True, ) - # CHECKPOINT... - @key_is_string def __getitem__(self, key) -> EdgeAttrDict: """g._adj['node/1']['node/2']""" @@ -718,7 +735,6 @@ def __getitem__(self, key) -> EdgeAttrDict: dst_node_id, self.graph.name, direction="OUTBOUND", - return_bool=False, ) if not edge: @@ -739,11 +755,24 @@ def __setitem__(self, key: str, value: dict | EdgeAttrDict): dst_node_type, dst_node_id = get_node_type_and_id(key, self.default_node_type) - edge_type = value.data.get("_edge_type") # pop? + edge_type = value.data.get("_edge_type") if edge_type is None: edge_type = self.edge_type_func(self.src_node_type, dst_node_type) data = value.data + + if edge_id := value.edge_id: + self.graph.delete_edge(edge_id) + + elif edge_id := aql_edge_id( + self.db, + self.src_node_id, + dst_node_id, + self.graph.name, + direction="OUTBOUND", + ): + self.graph.delete_edge(edge_id) + edge = self.graph.link(edge_type, self.src_node_id, dst_node_id, data) edge_attr_dict = self.edge_attr_dict_factory() @@ -755,7 +784,21 @@ def __setitem__(self, key: str, value: dict | EdgeAttrDict): @key_is_string def __delitem__(self, key: Any) -> None: """del g._adj['node/1']['node/2']""" - raise NotImplementedError("AdjListInnerDict.__delitem__()") + dst_node_id = get_node_id(key, self.default_node_type) + + edge_id = aql_edge_id( + self.db, + self.src_node_id, + dst_node_id, + self.graph.name, + direction="OUTBOUND", + ) + + if not edge_id: + return + + self.graph.delete_edge(edge_id) + self.data.pop(dst_node_id, None) def __len__(self) -> int: """len(g._adj['node/1'])""" @@ -895,8 +938,8 @@ def __setitem__(self, key: str, value: Any): @key_is_not_reserved def __delitem__(self, key: str): """del G._adj['node/1']['node/2']['foo']""" - del self.data[key] - doc_update(self.db, self.node_id, {key: None}) + doc_update(self.db, self.edge_id, {key: None}) + self.data.pop(key, None) def __iter__(self) -> Iterator[str]: """for key in G._adj['node/1']['node/2']""" @@ -922,7 +965,7 @@ def values(self, cache: bool = True): def items(self, cache: bool = True): """G._adj['node/1']['node/'2].items()""" - doc = self.db.document(self.node_id) + doc = self.db.document(self.edge_id) if cache: self.data = doc diff --git a/nx_arangodb/classes/exceptions.py b/nx_arangodb/classes/exceptions.py new file mode 100644 index 00000000..bd09d6b3 --- /dev/null +++ b/nx_arangodb/classes/exceptions.py @@ -0,0 +1,17 @@ +class NetworkXArangoDBException(Exception): + pass + + +EDGE_ALREADY_EXISTS_ERROR_CODE = 1210 + + +class EdgeAlreadyExists(NetworkXArangoDBException): + """Raised when trying to add an edge that already exists in the graph.""" + + pass + + +class AQLMultipleResultsFound(NetworkXArangoDBException): + """Raised when multiple results are returned from a query that was expected to return a single result.""" + + pass diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 4c02f9d0..5e206fd3 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -2,14 +2,19 @@ from typing import Any, Tuple +import arango import networkx as nx import numpy as np -from arango.collection import StandardCollection -from arango.cursor import Cursor -from arango.database import StandardDatabase +from arango import exceptions, graph import nx_arangodb as nxadb +from .exceptions import ( + EDGE_ALREADY_EXISTS_ERROR_CODE, + AQLMultipleResultsFound, + EdgeAlreadyExists, +) + def get_arangodb_graph( G: nxadb.Graph | nxadb.DiGraph, @@ -120,8 +125,8 @@ def wrapper(self, dict, *args, **kwargs) -> Any: def create_collection( - db: StandardDatabase, collection_name: str, edge: bool = False -) -> StandardCollection: + db: arango.StandardDatabase, collection_name: str, edge: bool = False +) -> arango.StandardCollection: """Creates a collection if it does not exist and returns it.""" if not db.has_collection(collection_name): db.create_collection(collection_name, edge=edge) @@ -130,63 +135,119 @@ def create_collection( def aql( - db: StandardDatabase, query: str, bind_vars: dict[str, Any], **kwargs -) -> Cursor: + db: arango.StandardDatabase, query: str, bind_vars: dict[str, Any], **kwargs +) -> arango.Cursor: """Executes an AQL query and returns the cursor.""" return db.aql.execute(query, bind_vars=bind_vars, stream=True, **kwargs) def aql_as_list( - db: StandardDatabase, query: str, bind_vars: dict[str, Any], **kwargs + db: arango.StandardDatabase, query: str, bind_vars: dict[str, Any], **kwargs ) -> list[Any]: """Executes an AQL query and returns the results as a list.""" return list(aql(db, query, bind_vars, **kwargs)) -def aql_single(db: StandardDatabase, query: str, bind_vars: dict[str, Any]) -> Any: +def aql_single( + db: arango.StandardDatabase, query: str, bind_vars: dict[str, Any] +) -> Any: """Executes an AQL query and returns the first result.""" result = aql_as_list(db, query, bind_vars) if len(result) == 0: return None + if len(result) > 1: + raise AQLMultipleResultsFound(f"Multiple results found: {result}") + return result[0] -def aql_doc_has_key(db: StandardDatabase, id: str, key: str) -> bool: +def aql_doc_has_key(db: arango.StandardDatabase, id: str, key: str) -> bool: """Checks if a document has a key.""" query = f"RETURN HAS(DOCUMENT(@id), @key)" bind_vars = {"id": id, "key": key} return aql_single(db, query, bind_vars) -def aql_doc_get_key(db: StandardDatabase, id: str, key: str) -> Any: +def aql_doc_get_key(db: arango.StandardDatabase, id: str, key: str) -> Any: """Gets a key from a document.""" query = f"RETURN DOCUMENT(@id).@key" bind_vars = {"id": id, "key": key} return aql_single(db, query, bind_vars) -def aql_doc_get_keys(db: StandardDatabase, id: str) -> list[str]: +def aql_doc_get_keys(db: arango.StandardDatabase, id: str) -> list[str]: """Gets the keys of a document.""" query = f"RETURN ATTRIBUTES(DOCUMENT(@id))" bind_vars = {"id": id} return aql_single(db, query, bind_vars) -def aql_doc_get_length(db: StandardDatabase, id: str) -> int: +def aql_doc_get_length(db: arango.StandardDatabase, id: str) -> int: """Gets the length of a document.""" query = f"RETURN LENGTH(DOCUMENT(@id))" bind_vars = {"id": id} return aql_single(db, query, bind_vars) +def aql_edge_exists( + db: arango.StandardDatabase, + src_node_id: str, + dst_node_id: str, + graph_name: str, + direction: str, +): + return aql_edge( + db, + src_node_id, + dst_node_id, + graph_name, + direction, + return_clause="true", + ) + + def aql_edge_get( - db: StandardDatabase, + db: arango.StandardDatabase, src_node_id: str, dst_node_id: str, graph_name: str, direction: str, - return_bool: bool, +): + return aql_edge( + db, + src_node_id, + dst_node_id, + graph_name, + direction, + return_clause="e", + ) + + +def aql_edge_id( + db: arango.StandardDatabase, + src_node_id: str, + dst_node_id: str, + graph_name: str, + direction: str, +): + return aql_edge( + db, + src_node_id, + dst_node_id, + graph_name, + direction, + return_clause="e._id", + ) + + +def aql_edge( + db: arango.StandardDatabase, + src_node_id: str, + dst_node_id: str, + graph_name: str, + direction: str, + return_clause: str, ): if direction == "INBOUND": filter_clause = f"e._from == @dst_node_id" @@ -197,8 +258,6 @@ def aql_edge_get( else: raise ValueError(f"Invalid direction: {direction}") - return_clause = "true" if return_bool else "e" - query = f""" FOR v, e IN 1..1 {direction} @src_node_id GRAPH @graph_name FILTER {filter_clause} @@ -214,25 +273,31 @@ def aql_edge_get( return aql_single(db, query, bind_vars) -def doc_update(db: StandardDatabase, id: str, data: dict[str, Any], **kwargs) -> None: +def doc_update( + db: arango.StandardDatabase, id: str, data: dict[str, Any], **kwargs +) -> None: """Updates a document in the collection.""" db.update_document({**data, "_id": id}, keep_none=False, silent=True, **kwargs) -def doc_delete(db: StandardDatabase, id: str, **kwargs) -> None: +def doc_delete(db: arango.StandardDatabase, id: str, **kwargs) -> None: """Deletes a document from the collection.""" db.delete_document(id, silent=True, **kwargs) def doc_insert( - db: StandardDatabase, collection: str, id: str, data: dict[str, Any] = {}, **kwargs + db: arango.StandardDatabase, + collection: str, + id: str, + data: dict[str, Any] = {}, + **kwargs, ) -> dict[str, Any] | bool: """Inserts a document into a collection.""" return db.insert_document(collection, {**data, "_id": id}, overwrite=True, **kwargs) def doc_get_or_insert( - db: StandardDatabase, collection: str, id: str, **kwargs + db: arango.StandardDatabase, collection: str, id: str, **kwargs ) -> dict[str, Any]: """Loads a document if existing, otherwise inserts it & returns it.""" if db.has_document(id): diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 98a8a6c1..d676e8d9 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -38,7 +38,7 @@ def to_networkx_class(cls) -> type[nx.Graph]: def __init__( self, graph_name: str | None = None, - default_node_type: str = "NXADB_NODES", + default_node_type: str = "nxadb_nodes", edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}", *args, **kwargs, @@ -279,6 +279,10 @@ def query(self, query: str, bind_vars: dict | None = None, **kwargs) -> Cursor: def nodes(self): return CustomNodeView(self) + # @cached_property + # def edges(self): + # return CustomEdgeView(self) + def add_node(self, node_for_adding, **attr): if node_for_adding not in self._node: if node_for_adding is None: