diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 9477c60..e9d5f6a 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -194,6 +194,7 @@ def __init__( and not self._loaded_incoming_graph_data ): nx.convert.to_networkx_graph(incoming_graph_data, create_using=self) + self._loaded_incoming_graph_data = True ####################### # nx.DiGraph Overides # @@ -241,12 +242,25 @@ def add_node_override(self, node_for_adding, **attr): # attr_dict.update(attr) # New: + + node_attr_dict = self.node_attr_dict_factory() + + if self.is_smart: + if self.smart_field not in attr: + m = f"Node {node_for_adding} missing smart field '{self.smart_field}'" # noqa: E501 + raise KeyError(m) + + node_attr_dict.data[self.smart_field] = attr[self.smart_field] + self._node[node_for_adding] = self.node_attr_dict_factory() self._node[node_for_adding].update(attr) # Reason: # Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set # i.e trying to update a node's attributes before we know _which_ node it is + # Furthermore, support for ArangoDB Smart Graphs requires the smart field + # to be set before adding the node to the graph. This is because the smart + # field is used to generate the node's key. ########################### diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 9cadb06..96136df 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -3,7 +3,8 @@ from typing import Any, Callable, ClassVar import networkx as nx -from adbnx_adapter import ADBNX_Adapter +from adbnx_adapter import ADBNX_Adapter, ADBNX_Controller +from adbnx_adapter.typings import NxData, NxId from arango import ArangoClient from arango.cursor import Cursor from arango.database import StandardDatabase @@ -186,19 +187,18 @@ def __init__( write_async: bool = True, symmetrize_edges: bool = False, use_arango_views: bool = False, + overwrite_graph: bool = False, *args: Any, **kwargs: Any, ): self.__db = None - self.__name = None self.__use_arango_views = use_arango_views self.__graph_exists_in_db = False self.__set_db(db) - if self.__db is not None: - self.__set_graph_name(name) - - self.__set_edge_collections_attributes(edge_collections_attributes) + if all([self.__db, name]): + self.__set_graph(name, default_node_type, edge_type_func) + self.__set_edge_collections_attributes(edge_collections_attributes) # NOTE: Need to revisit these... # self.maintain_node_dict_cache = False @@ -219,96 +219,25 @@ def __init__( # raise ValueError(m) self._loaded_incoming_graph_data = False - - if self.__graph_exists_in_db: - if incoming_graph_data is not None: - m = "Cannot pass both **incoming_graph_data** and **name** yet if the already graph exists" # noqa: E501 - raise NotImplementedError(m) - - if edge_type_func is not None: - m = "Cannot pass **edge_type_func** if the graph already exists" - raise NotImplementedError(m) - - self.adb_graph = self.db.graph(self.__name) - vertex_collections = self.adb_graph.vertex_collections() - edge_definitions = self.adb_graph.edge_definitions() - - if default_node_type is None: - default_node_type = list(vertex_collections)[0] - logger.info(f"Default node type set to '{default_node_type}'") - elif default_node_type not in vertex_collections: - m = f"Default node type '{default_node_type}' not found in graph '{name}'" # noqa: E501 - raise InvalidDefaultNodeType(m) - - node_types_to_edge_type_map: dict[tuple[str, str], str] = {} - for e_d in edge_definitions: - for f in e_d["from_vertex_collections"]: - for t in e_d["to_vertex_collections"]: - if (f, t) in node_types_to_edge_type_map: - # TODO: Should we log a warning at least? - continue - - node_types_to_edge_type_map[(f, t)] = e_d["edge_collection"] - - def edge_type_func(u: str, v: str) -> str: - try: - return node_types_to_edge_type_map[(u, v)] - except KeyError: - m = f"Edge type ambiguity between '{u}' and '{v}'" - raise EdgeTypeAmbiguity(m) - - self.edge_type_func = edge_type_func - self.default_node_type = default_node_type - + if self.graph_exists_in_db: self._set_factory_methods() self.__set_arangodb_backend_config(read_parallelism, read_batch_size) - elif self.__name: + if overwrite_graph: + logger.info("Truncating graph collections...") - prefix = f"{name}_" if name else "" - if default_node_type is None: - default_node_type = f"{prefix}node" - if edge_type_func is None: - edge_type_func = lambda u, v: f"{u}_to_{v}" # noqa: E731 + for col in self.adb_graph.vertex_collections(): + self.db.collection(col).truncate() - self.edge_type_func = edge_type_func - self.default_node_type = default_node_type - - # TODO: Parameterize the edge definitions - # How can we work with a heterogenous **incoming_graph_data**? - default_edge_type = edge_type_func(default_node_type, default_node_type) - edge_definitions = [ - { - "edge_collection": default_edge_type, - "from_vertex_collections": [default_node_type], - "to_vertex_collections": [default_node_type], - } - ] + for col in self.adb_graph.edge_definitions(): + self.db.collection(col["edge_collection"]).truncate() if isinstance(incoming_graph_data, nx.Graph): - self.adb_graph = ADBNX_Adapter(self.db).networkx_to_arangodb( - self.__name, - incoming_graph_data, - edge_definitions=edge_definitions, - batch_size=write_batch_size, - use_async=write_async, - ) - + self._load_nx_graph(incoming_graph_data, write_batch_size, write_async) self._loaded_incoming_graph_data = True - else: - self.adb_graph = self.db.create_graph( - self.__name, - edge_definitions=edge_definitions, - ) - - self._set_factory_methods() - self.__set_arangodb_backend_config(read_parallelism, read_batch_size) - logger.info(f"Graph '{name}' created.") - self.__graph_exists_in_db = True - - if self.__name is not None: - kwargs["name"] = self.__name + if name is not None: + kwargs["name"] = name super().__init__(*args, **kwargs) @@ -333,6 +262,7 @@ def edge_type_func(u: str, v: str) -> str: and not self._loaded_incoming_graph_data ): nx.convert.to_networkx_graph(incoming_graph_data, create_using=self) + self._loaded_incoming_graph_data = True ####################### # Init helper methods # @@ -423,23 +353,118 @@ def __set_db(self, db: Any = None) -> None: self._db_name, self._username, self._password, verify=True ) - def __set_graph_name(self, name: Any = None) -> None: - if self.__db is None: - m = "Cannot set graph name without setting the database first" - raise DatabaseNotSet(m) - - if not name: - self.__graph_exists_in_db = False - logger.warning(f"**name** not set for {self.__class__.__name__}") - return - + def __set_graph( + self, + name: Any, + default_node_type: str | None = None, + edge_type_func: Callable[[str, str], str] | None = None, + ) -> None: if not isinstance(name, str): raise TypeError("**name** must be a string") + if self.db.has_graph(name): + logger.info(f"Graph '{name}' exists.") + + if edge_type_func is not None: + m = "Cannot pass **edge_type_func** if the graph already exists" + raise NotImplementedError(m) + + self.adb_graph = self.db.graph(name) + vertex_collections = self.adb_graph.vertex_collections() + edge_definitions = self.adb_graph.edge_definitions() + + if default_node_type is None: + default_node_type = list(vertex_collections)[0] + logger.info(f"Default node type set to '{default_node_type}'") + + elif default_node_type not in vertex_collections: + m = f"Default node type '{default_node_type}' not found in graph '{name}'" # noqa: E501 + raise InvalidDefaultNodeType(m) + + node_types_to_edge_type_map: dict[tuple[str, str], str] = {} + for e_d in edge_definitions: + for f in e_d["from_vertex_collections"]: + for t in e_d["to_vertex_collections"]: + if (f, t) in node_types_to_edge_type_map: + # TODO: Should we log a warning at least? + continue + + node_types_to_edge_type_map[(f, t)] = e_d["edge_collection"] + + def edge_type_func(u: str, v: str) -> str: + try: + return node_types_to_edge_type_map[(u, v)] + except KeyError: + m = f"Edge type ambiguity between '{u}' and '{v}'" + raise EdgeTypeAmbiguity(m) + + else: + prefix = f"{name}_" if name else "" + + if default_node_type is None: + default_node_type = f"{prefix}node" + + if edge_type_func is None: + edge_type_func = lambda u, v: f"{u}_to_{v}" # noqa: E731 + + # TODO: Parameterize the edge definitions + # How can we work with a heterogenous **incoming_graph_data**? + default_edge_type = edge_type_func(default_node_type, default_node_type) + edge_definitions = [ + { + "edge_collection": default_edge_type, + "from_vertex_collections": [default_node_type], + "to_vertex_collections": [default_node_type], + } + ] + + # Create a general graph if it doesn't exist + self.adb_graph = self.db.create_graph( + name=name, + edge_definitions=edge_definitions, + ) + + logger.info(f"Graph '{name}' created.") + self.__name = name - self.__graph_exists_in_db = self.db.has_graph(name) + self.__graph_exists_in_db = True + self.edge_type_func = edge_type_func + self.default_node_type = default_node_type + + properties = self.adb_graph.properties() + self.__is_smart: bool = properties.get("smart", False) + self.__smart_field: str | None = properties.get("smart_field") + + def _load_nx_graph( + self, nx_graph: nx.Graph, write_batch_size: int, write_async: bool + ) -> None: + controller = ADBNX_Controller + + if all([self.is_smart, self.smart_field]): + smart_field = self.__smart_field - logger.info(f"Graph '{name}' exists: {self.__graph_exists_in_db}") + class SmartController(ADBNX_Controller): + def _keyify_networkx_node( + self, i: int, nx_node_id: NxId, nx_node: NxData, col: str + ) -> str: + if smart_field not in nx_node: + m = f"Node {nx_node_id} missing smart field '{smart_field}'" # noqa: E501 + raise KeyError(m) + + return f"{nx_node[smart_field]}:{str(i)}" + + def _prepare_networkx_edge(self, nx_edge: NxData, col: str) -> None: + del nx_edge["_key"] + + controller = SmartController + logger.info(f"Using smart field '{smart_field}' for node keys") + + ADBNX_Adapter(self.db, controller()).networkx_to_arangodb( + self.adb_graph.name, + nx_graph, + batch_size=write_batch_size, + use_async=write_async, + ) ########### # Getters # @@ -479,6 +504,14 @@ def graph_exists_in_db(self) -> bool: def edge_attributes(self) -> set[str]: return self._edge_collections_attributes + @property + def is_smart(self) -> bool: + return self.__is_smart + + @property + def smart_field(self) -> str | None: + return self.__smart_field + ########### # Setters # ########### @@ -645,12 +678,24 @@ def add_node_override(self, node_for_adding, **attr): # attr_dict.update(attr) # New: - self._node[node_for_adding] = self.node_attr_dict_factory() + node_attr_dict = self.node_attr_dict_factory() + + if self.is_smart: + if self.smart_field not in attr: + m = f"Node {node_for_adding} missing smart field '{self.smart_field}'" # noqa: E501 + raise KeyError(m) + + node_attr_dict.data[self.smart_field] = attr[self.smart_field] + + self._node[node_for_adding] = node_attr_dict self._node[node_for_adding].update(attr) # Reason: # Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set # i.e trying to update a node's attributes before we know _which_ node it is + # Furthermore, support for ArangoDB Smart Graphs requires the smart field + # to be set before adding the node to the graph. This is because the smart + # field is used to generate the node's key. ########################### diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index 07c30b7..1100025 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -215,6 +215,8 @@ def __init__( else: nx.convert.to_networkx_graph(incoming_graph_data, create_using=self) + self._loaded_incoming_graph_data = True + ####################### # Init helper methods # #######################