diff --git a/nx_arangodb/classes/dict/node.py b/nx_arangodb/classes/dict/node.py index 0ef179d..872b158 100644 --- a/nx_arangodb/classes/dict/node.py +++ b/nx_arangodb/classes/dict/node.py @@ -221,7 +221,8 @@ def update(self, attrs: Any) -> None: if not attrs: return - self.data.update(build_node_attr_dict_data(self, attrs)) + node_attr_dict_data = build_node_attr_dict_data(self, attrs) + self.data.update(node_attr_dict_data) if not self.node_id: logger.debug("Node ID not set, skipping NodeAttrDict(?).update()") @@ -275,10 +276,12 @@ def __init__( self.FETCHED_ALL_DATA = False self.FETCHED_ALL_IDS = False - def _create_node_attr_dict(self, vertex: dict[str, Any]) -> NodeAttrDict: + def _create_node_attr_dict( + self, node_id: str, node_data: dict[str, Any] + ) -> NodeAttrDict: node_attr_dict = self.node_attr_dict_factory() - node_attr_dict.node_id = vertex["_id"] - node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, vertex) + node_attr_dict.node_id = node_id + node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data) return node_attr_dict @@ -321,8 +324,8 @@ def __getitem__(self, key: str) -> NodeAttrDict: if node_id not in self.data and self.FETCHED_ALL_IDS: raise KeyError(key) - if vertex_db := vertex_get(self.graph, node_id): - node_attr_dict = self._create_node_attr_dict(vertex_db) + if node := vertex_get(self.graph, node_id): + node_attr_dict = self._create_node_attr_dict(node["_id"], node) self.data[node_id] = node_attr_dict return node_attr_dict @@ -331,18 +334,16 @@ def __getitem__(self, key: str) -> NodeAttrDict: @key_is_string def __setitem__(self, key: str, value: NodeAttrDict) -> None: - """G._node['node/1'] = {'foo': 'bar'} - - Not to be confused with: - - G.add_node('node/1', foo='bar') - """ + """G._node['node/1'] = {'foo': 'bar'}""" assert isinstance(value, NodeAttrDict) node_type, node_id = get_node_type_and_id(key, self.default_node_type) result = doc_insert(self.db, node_type, node_id, value.data) - node_attr_dict = self._create_node_attr_dict(result) + node_attr_dict = self._create_node_attr_dict( + result["_id"], {**value.data, **result} + ) self.data[node_id] = node_attr_dict @@ -405,10 +406,7 @@ def copy(self) -> Any: @keys_are_strings def __update_local_nodes(self, nodes: Any) -> None: for node_id, node_data in nodes.items(): - node_attr_dict = self.node_attr_dict_factory() - node_attr_dict.node_id = node_id - node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data) - + node_attr_dict = self._create_node_attr_dict(node_id, node_data) self.data[node_id] = node_attr_dict @keys_are_strings @@ -478,7 +476,7 @@ def _fetch_all(self): for node_id, node_data in node_dict.items(): del node_data["_rev"] # TODO: Optimize away via phenolrs - node_attr_dict = self._create_node_attr_dict(node_data) + node_attr_dict = self._create_node_attr_dict(node_data["_id"], node_data) self.data[node_id] = node_attr_dict self.FETCHED_ALL_DATA = True diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 9477c60..3bfb943 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -60,8 +60,8 @@ class DiGraph(Graph, nx.DiGraph): name : str (optional, default: None) Name of the graph in the database. If the graph already exists, the user can pass the name of the graph to connect to it. If - the graph does not exist, the user can create a new graph by - passing the name. NOTE: Must be used in conjunction with + the graph does not exist, a General Graph will be created by + passing the **name**. NOTE: Must be used in conjunction with **incoming_graph_data** if the user wants to persist the graph in ArangoDB. @@ -125,6 +125,12 @@ class DiGraph(Graph, nx.DiGraph): whenever possible. NOTE: This feature is experimental and may not work as expected. + overwrite_graph : bool (optional, default: False) + Whether to overwrite the graph in the database if it already exists. If + set to True, the graph collections will be dropped and recreated. Note that + this operation is irreversible and will result in the loss of all data in + the graph. NOTE: If set to True, Collection Indexes will also be lost. + args: positional arguments for nx.Graph Additional arguments passed to nx.Graph. @@ -154,6 +160,7 @@ def __init__( write_async: bool = True, symmetrize_edges: bool = False, use_arango_views: bool = False, + overwrite_graph: bool = False, *args: Any, **kwargs: Any, ): @@ -171,6 +178,7 @@ def __init__( write_async, symmetrize_edges, use_arango_views, + overwrite_graph, *args, **kwargs, ) @@ -178,6 +186,7 @@ def __init__( if self.graph_exists_in_db: self.clear_edges = self.clear_edges_override self.add_node = self.add_node_override + self.add_nodes_from = self.add_nodes_from_override self.remove_node = self.remove_node_override self.reverse = self.reverse_override @@ -194,6 +203,7 @@ def __init__( and not self._loaded_incoming_graph_data ): nx.convert.to_networkx_graph(incoming_graph_data, create_using=self) + self._loaded_incoming_graph_data = True ####################### # nx.DiGraph Overides # @@ -225,9 +235,10 @@ def clear_edges_override(self): super().clear_edges() def add_node_override(self, node_for_adding, **attr): + if node_for_adding is None: + raise ValueError("None cannot be a node") + if node_for_adding not in self._succ: - if node_for_adding is None: - raise ValueError("None cannot be a node") self._succ[node_for_adding] = self.adjlist_inner_dict_factory() self._pred[node_for_adding] = self.adjlist_inner_dict_factory() @@ -241,12 +252,15 @@ def add_node_override(self, node_for_adding, **attr): # attr_dict.update(attr) # New: - self._node[node_for_adding] = self.node_attr_dict_factory() - self._node[node_for_adding].update(attr) + node_attr_dict = self.node_attr_dict_factory() + node_attr_dict.data = attr + self._node[node_for_adding] = node_attr_dict # Reason: - # Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set - # i.e trying to update a node's attributes before we know _which_ node it is + # We can optimize the process of adding a node by creating avoiding + # the creation of a new dictionary and updating it with the attributes. + # Instead, we can create a new node_attr_dict object and set the attributes + # directly. This only makes 1 network call to the database instead of 2. ########################### @@ -255,6 +269,49 @@ def add_node_override(self, node_for_adding, **attr): nx._clear_cache(self) + def add_nodes_from_override(self, nodes_for_adding, **attr): + for n in nodes_for_adding: + try: + newnode = n not in self._node + newdict = attr + except TypeError: + n, ndict = n + newnode = n not in self._node + newdict = attr.copy() + newdict.update(ndict) + if newnode: + if n is None: + raise ValueError("None cannot be a node") + self._succ[n] = self.adjlist_inner_dict_factory() + self._pred[n] = self.adjlist_inner_dict_factory() + + ###################### + # NOTE: monkey patch # + ###################### + + # Old: + # self._node[n] = self.node_attr_dict_factory() + # + # self._node[n].update(newdict) + + # New: + node_attr_dict = self.node_attr_dict_factory() + node_attr_dict.data = newdict + self._node[n] = node_attr_dict + + else: + self._node[n].update(newdict) + + # Reason: + # We can optimize the process of adding a node by creating avoiding + # the creation of a new dictionary and updating it with the attributes. + # Instead, we create a new node_attr_dict object and set the attributes + # directly. This only makes 1 network call to the database instead of 2. + + ########################### + + nx._clear_cache(self) + def remove_node_override(self, n): if isinstance(n, (str, int)): n = get_node_id(str(n), self.default_node_type) diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 993db90..491c0cd 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -199,6 +199,17 @@ def to_dict(self): return cls +def cast_to_string(value: Any) -> str: + """Casts a value to a string.""" + if isinstance(value, str): + return value + + if isinstance(value, (int, float)): + return str(value) + + raise TypeError(f"{value} cannot be casted to string.") + + def key_is_string(func: Callable[..., Any]) -> Any: """Decorator to check if the key is a string. Will attempt to cast the key to a string if it is not. @@ -208,12 +219,7 @@ def wrapper(self: Any, key: Any, *args: Any, **kwargs: Any) -> Any: if key is None: raise ValueError("Key cannot be None.") - if not isinstance(key, str): - if not isinstance(key, (int, float)): - raise TypeError(f"{key} cannot be casted to string.") - - key = str(key) - + key = cast_to_string(key) return func(self, key, *args, **kwargs) return wrapper @@ -270,12 +276,7 @@ def wrapper(self: Any, data: Any, *args: Any, **kwargs: Any) -> Any: raise TypeError(f"Decorator found unsupported type: {type(data)}.") for key, value in items: - if not isinstance(key, str): - if not isinstance(key, (int, float)): - raise TypeError(f"{key} cannot be casted to string.") - - key = str(key) - + key = cast_to_string(key) data_dict[key] = value return func(self, data_dict, *args, **kwargs) @@ -655,7 +656,7 @@ def doc_insert( data: dict[str, Any] = {}, **kwargs: Any, ) -> dict[str, Any]: - """Inserts a document into a collection.""" + """Inserts a document into a collection. Returns document metadata.""" result: dict[str, Any] = db.insert_document( collection, {**data, "_id": id}, overwrite=True, **kwargs ) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 9cadb06..336cc3b 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -3,7 +3,8 @@ from typing import Any, Callable, ClassVar import networkx as nx -from adbnx_adapter import ADBNX_Adapter +from adbnx_adapter import ADBNX_Adapter, ADBNX_Controller +from adbnx_adapter.typings import NxData, NxId from arango import ArangoClient from arango.cursor import Cursor from arango.database import StandardDatabase @@ -14,6 +15,7 @@ DatabaseNotSet, EdgeTypeAmbiguity, GraphNameNotSet, + GraphNotEmpty, InvalidDefaultNodeType, ) from nx_arangodb.logger import logger @@ -92,8 +94,8 @@ class Graph(nx.Graph): name : str (optional, default: None) Name of the graph in the database. If the graph already exists, the user can pass the name of the graph to connect to it. If - the graph does not exist, the user can create a new graph by - passing the name. NOTE: Must be used in conjunction with + the graph does not exist, a General Graph will be created by + passing the **name**. NOTE: Must be used in conjunction with **incoming_graph_data** if the user wants to persist the graph in ArangoDB. @@ -157,6 +159,12 @@ class Graph(nx.Graph): whenever possible. NOTE: This feature is experimental and may not work as expected. + overwrite_graph : bool (optional, default: False) + Whether to overwrite the graph in the database if it already exists. If + set to True, the graph collections will be dropped and recreated. Note that + this operation is irreversible and will result in the loss of all data in + the graph. NOTE: If set to True, Collection Indexes will also be lost. + args: positional arguments for nx.Graph Additional arguments passed to nx.Graph. @@ -186,19 +194,18 @@ def __init__( write_async: bool = True, symmetrize_edges: bool = False, use_arango_views: bool = False, + overwrite_graph: bool = False, *args: Any, **kwargs: Any, ): self.__db = None - self.__name = None self.__use_arango_views = use_arango_views self.__graph_exists_in_db = False self.__set_db(db) - if self.__db is not None: - self.__set_graph_name(name) - - self.__set_edge_collections_attributes(edge_collections_attributes) + if all([self.__db, name]): + self.__set_graph(name, default_node_type, edge_type_func) + self.__set_edge_collections_attributes(edge_collections_attributes) # NOTE: Need to revisit these... # self.maintain_node_dict_cache = False @@ -219,96 +226,33 @@ def __init__( # raise ValueError(m) self._loaded_incoming_graph_data = False - - if self.__graph_exists_in_db: - if incoming_graph_data is not None: - m = "Cannot pass both **incoming_graph_data** and **name** yet if the already graph exists" # noqa: E501 - raise NotImplementedError(m) - - if edge_type_func is not None: - m = "Cannot pass **edge_type_func** if the graph already exists" - raise NotImplementedError(m) - - self.adb_graph = self.db.graph(self.__name) - vertex_collections = self.adb_graph.vertex_collections() - edge_definitions = self.adb_graph.edge_definitions() - - if default_node_type is None: - default_node_type = list(vertex_collections)[0] - logger.info(f"Default node type set to '{default_node_type}'") - elif default_node_type not in vertex_collections: - m = f"Default node type '{default_node_type}' not found in graph '{name}'" # noqa: E501 - raise InvalidDefaultNodeType(m) - - node_types_to_edge_type_map: dict[tuple[str, str], str] = {} - for e_d in edge_definitions: - for f in e_d["from_vertex_collections"]: - for t in e_d["to_vertex_collections"]: - if (f, t) in node_types_to_edge_type_map: - # TODO: Should we log a warning at least? - continue - - node_types_to_edge_type_map[(f, t)] = e_d["edge_collection"] - - def edge_type_func(u: str, v: str) -> str: - try: - return node_types_to_edge_type_map[(u, v)] - except KeyError: - m = f"Edge type ambiguity between '{u}' and '{v}'" - raise EdgeTypeAmbiguity(m) - - self.edge_type_func = edge_type_func - self.default_node_type = default_node_type - + if self.graph_exists_in_db: self._set_factory_methods() self.__set_arangodb_backend_config(read_parallelism, read_batch_size) - elif self.__name: - - prefix = f"{name}_" if name else "" - if default_node_type is None: - default_node_type = f"{prefix}node" - if edge_type_func is None: - edge_type_func = lambda u, v: f"{u}_to_{v}" # noqa: E731 - - self.edge_type_func = edge_type_func - self.default_node_type = default_node_type - - # TODO: Parameterize the edge definitions - # How can we work with a heterogenous **incoming_graph_data**? - default_edge_type = edge_type_func(default_node_type, default_node_type) - edge_definitions = [ - { - "edge_collection": default_edge_type, - "from_vertex_collections": [default_node_type], - "to_vertex_collections": [default_node_type], - } - ] - - if isinstance(incoming_graph_data, nx.Graph): - self.adb_graph = ADBNX_Adapter(self.db).networkx_to_arangodb( - self.__name, - incoming_graph_data, - edge_definitions=edge_definitions, - batch_size=write_batch_size, - use_async=write_async, + if overwrite_graph: + logger.info("Overwriting graph...") + + properties = self.adb_graph.properties() + self.db.delete_graph(name, drop_collections=True) + self.db.create_graph( + name=name, + edge_definitions=properties["edge_definitions"], + orphan_collections=properties["orphan_collections"], + smart=properties.get("smart"), + disjoint=properties.get("disjoint"), + smart_field=properties.get("smart_field"), + shard_count=properties.get("shard_count"), + replication_factor=properties.get("replication_factor"), + write_concern=properties.get("write_concern"), ) + if isinstance(incoming_graph_data, nx.Graph): + self._load_nx_graph(incoming_graph_data, write_batch_size, write_async) self._loaded_incoming_graph_data = True - else: - self.adb_graph = self.db.create_graph( - self.__name, - edge_definitions=edge_definitions, - ) - - self._set_factory_methods() - self.__set_arangodb_backend_config(read_parallelism, read_batch_size) - logger.info(f"Graph '{name}' created.") - self.__graph_exists_in_db = True - - if self.__name is not None: - kwargs["name"] = self.__name + if name is not None: + kwargs["name"] = name super().__init__(*args, **kwargs) @@ -318,6 +262,7 @@ def edge_type_func(u: str, v: str) -> str: self.clear = self.clear_override self.clear_edges = self.clear_edges_override self.add_node = self.add_node_override + self.add_nodes_from = self.add_nodes_from_override self.number_of_edges = self.number_of_edges_override self.nbunch_iter = self.nbunch_iter_override @@ -333,6 +278,7 @@ def edge_type_func(u: str, v: str) -> str: and not self._loaded_incoming_graph_data ): nx.convert.to_networkx_graph(incoming_graph_data, create_using=self) + self._loaded_incoming_graph_data = True ####################### # Init helper methods # @@ -423,23 +369,131 @@ def __set_db(self, db: Any = None) -> None: self._db_name, self._username, self._password, verify=True ) - def __set_graph_name(self, name: Any = None) -> None: - if self.__db is None: - m = "Cannot set graph name without setting the database first" - raise DatabaseNotSet(m) - - if not name: - self.__graph_exists_in_db = False - logger.warning(f"**name** not set for {self.__class__.__name__}") - return - + def __set_graph( + self, + name: Any, + default_node_type: str | None = None, + edge_type_func: Callable[[str, str], str] | None = None, + ) -> None: if not isinstance(name, str): raise TypeError("**name** must be a string") + if self.db.has_graph(name): + logger.info(f"Graph '{name}' exists.") + + if edge_type_func is not None: + m = "Cannot pass **edge_type_func** if the graph already exists" + raise NotImplementedError(m) + + self.adb_graph = self.db.graph(name) + vertex_collections = self.adb_graph.vertex_collections() + edge_definitions = self.adb_graph.edge_definitions() + + if default_node_type is None: + default_node_type = list(vertex_collections)[0] + logger.info(f"Default node type set to '{default_node_type}'") + + elif default_node_type not in vertex_collections: + m = f"Default node type '{default_node_type}' not found in graph '{name}'" # noqa: E501 + raise InvalidDefaultNodeType(m) + + node_types_to_edge_type_map: dict[tuple[str, str], str] = {} + for e_d in edge_definitions: + for f in e_d["from_vertex_collections"]: + for t in e_d["to_vertex_collections"]: + if (f, t) in node_types_to_edge_type_map: + # TODO: Should we log a warning at least? + continue + + node_types_to_edge_type_map[(f, t)] = e_d["edge_collection"] + + def edge_type_func(u: str, v: str) -> str: + try: + return node_types_to_edge_type_map[(u, v)] + except KeyError: + m = f"Edge type ambiguity between '{u}' and '{v}'" + raise EdgeTypeAmbiguity(m) + + else: + prefix = f"{name}_" if name else "" + + if default_node_type is None: + default_node_type = f"{prefix}node" + + if edge_type_func is None: + edge_type_func = lambda u, v: f"{u}_to_{v}" # noqa: E731 + + # TODO: Parameterize the edge definitions + # How can we work with a heterogenous **incoming_graph_data**? + default_edge_type = edge_type_func(default_node_type, default_node_type) + edge_definitions = [ + { + "edge_collection": default_edge_type, + "from_vertex_collections": [default_node_type], + "to_vertex_collections": [default_node_type], + } + ] + + # Create a general graph if it doesn't exist + self.adb_graph = self.db.create_graph( + name=name, + edge_definitions=edge_definitions, + ) + + logger.info(f"Graph '{name}' created.") + self.__name = name - self.__graph_exists_in_db = self.db.has_graph(name) + self.__graph_exists_in_db = True + self.edge_type_func = edge_type_func + self.default_node_type = default_node_type + + properties = self.adb_graph.properties() + self.__is_smart: bool = properties.get("smart", False) + self.__smart_field: str | None = properties.get("smart_field") + + def _load_nx_graph( + self, nx_graph: nx.Graph, write_batch_size: int, write_async: bool + ) -> None: + collections = list(self.adb_graph.vertex_collections()) + collections += [e["edge_collection"] for e in self.adb_graph.edge_definitions()] + + for col in collections: + cursor = self.db.aql.execute( + "FOR doc IN @@collection LIMIT 1 RETURN 1", + bind_vars={"@collection": col}, + ) + + if not cursor.empty(): + m = f"Graph '{self.adb_graph.name}' already has data (in '{col}'). Use **overwrite_graph=True** to clear it." # noqa: E501 + raise GraphNotEmpty(m) + + controller = ADBNX_Controller + + if all([self.is_smart, self.smart_field]): + smart_field = self.__smart_field - logger.info(f"Graph '{name}' exists: {self.__graph_exists_in_db}") + class SmartController(ADBNX_Controller): + def _keyify_networkx_node( + self, i: int, nx_node_id: NxId, nx_node: NxData, col: str + ) -> str: + if smart_field not in nx_node: + m = f"Node {nx_node_id} missing smart field '{smart_field}'" # noqa: E501 + raise KeyError(m) + + return f"{nx_node[smart_field]}:{str(i)}" + + def _prepare_networkx_edge(self, nx_edge: NxData, col: str) -> None: + del nx_edge["_key"] + + controller = SmartController + logger.info(f"Using smart field '{smart_field}' for node keys") + + ADBNX_Adapter(self.db, controller()).networkx_to_arangodb( + self.adb_graph.name, + nx_graph, + batch_size=write_batch_size, + use_async=write_async, + ) ########### # Getters # @@ -479,6 +533,14 @@ def graph_exists_in_db(self) -> bool: def edge_attributes(self) -> set[str]: return self._edge_collections_attributes + @property + def is_smart(self) -> bool: + return self.__is_smart + + @property + def smart_field(self) -> str | None: + return self.__smart_field + ########### # Setters # ########### @@ -630,10 +692,10 @@ def clear_edges_override(self): nx._clear_cache(self) def add_node_override(self, node_for_adding, **attr): - if node_for_adding not in self._node: - if node_for_adding is None: - raise ValueError("None cannot be a node") + if node_for_adding is None: + raise ValueError("None cannot be a node") + if node_for_adding not in self._node: self._adj[node_for_adding] = self.adjlist_inner_dict_factory() ###################### @@ -645,12 +707,15 @@ def add_node_override(self, node_for_adding, **attr): # attr_dict.update(attr) # New: - self._node[node_for_adding] = self.node_attr_dict_factory() - self._node[node_for_adding].update(attr) + node_attr_dict = self.node_attr_dict_factory() + node_attr_dict.data = attr + self._node[node_for_adding] = node_attr_dict # Reason: - # Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set - # i.e trying to update a node's attributes before we know _which_ node it is + # We can optimize the process of adding a node by creating avoiding + # the creation of a new dictionary and updating it with the attributes. + # Instead, we can create a new node_attr_dict object and set the attributes + # directly. This only makes 1 network call to the database instead of 2. ########################### @@ -659,6 +724,48 @@ def add_node_override(self, node_for_adding, **attr): nx._clear_cache(self) + def add_nodes_from_override(self, nodes_for_adding, **attr): + for n in nodes_for_adding: + try: + newnode = n not in self._node + newdict = attr + except TypeError: + n, ndict = n + newnode = n not in self._node + newdict = attr.copy() + newdict.update(ndict) + if newnode: + if n is None: + raise ValueError("None cannot be a node") + self._adj[n] = self.adjlist_inner_dict_factory() + + ###################### + # NOTE: monkey patch # + ###################### + + # Old: + # self._node[n] = self.node_attr_dict_factory() + # + # self._node[n].update(newdict) + + # New: + node_attr_dict = self.node_attr_dict_factory() + node_attr_dict.data = newdict + self._node[n] = node_attr_dict + + else: + self._node[n].update(newdict) + + # Reason: + # We can optimize the process of adding a node by creating avoiding + # the creation of a new dictionary and updating it with the attributes. + # Instead, we create a new node_attr_dict object and set the attributes + # directly. This only makes 1 network call to the database instead of 2. + + ########################### + + nx._clear_cache(self) + def number_of_edges_override(self, u=None, v=None): if u is not None: return super().number_of_edges(u, v) diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py index dc05e59..f115ab8 100644 --- a/nx_arangodb/classes/multidigraph.py +++ b/nx_arangodb/classes/multidigraph.py @@ -70,8 +70,8 @@ class MultiDiGraph(MultiGraph, DiGraph, nx.MultiDiGraph): name : str (optional, default: None) Name of the graph in the database. If the graph already exists, the user can pass the name of the graph to connect to it. If - the graph does not exist, the user can create a new graph by - passing the name. NOTE: Must be used in conjunction with + the graph does not exist, a General Graph will be created by + passing the **name**. NOTE: Must be used in conjunction with **incoming_graph_data** if the user wants to persist the graph in ArangoDB. @@ -135,6 +135,12 @@ class MultiDiGraph(MultiGraph, DiGraph, nx.MultiDiGraph): whenever possible. NOTE: This feature is experimental and may not work as expected. + overwrite_graph : bool (optional, default: False) + Whether to overwrite the graph in the database if it already exists. If + set to True, the graph collections will be dropped and recreated. Note that + this operation is irreversible and will result in the loss of all data in + the graph. NOTE: If set to True, Collection Indexes will also be lost. + args: positional arguments for nx.Graph Additional arguments passed to nx.Graph. @@ -165,6 +171,7 @@ def __init__( write_async: bool = True, symmetrize_edges: bool = False, use_arango_views: bool = False, + overwrite_graph: bool = False, *args: Any, **kwargs: Any, ): @@ -183,6 +190,7 @@ def __init__( write_async, symmetrize_edges, use_arango_views, + overwrite_graph, *args, **kwargs, ) diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index 07c30b7..c494d34 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -71,8 +71,8 @@ class MultiGraph(Graph, nx.MultiGraph): name : str (optional, default: None) Name of the graph in the database. If the graph already exists, the user can pass the name of the graph to connect to it. If - the graph does not exist, the user can create a new graph by - passing the name. NOTE: Must be used in conjunction with + the graph does not exist, a General Graph will be created by + passing the **name**. NOTE: Must be used in conjunction with **incoming_graph_data** if the user wants to persist the graph in ArangoDB. @@ -136,6 +136,12 @@ class MultiGraph(Graph, nx.MultiGraph): whenever possible. NOTE: This feature is experimental and may not work as expected. + overwrite_graph : bool (optional, default: False) + Whether to overwrite the graph in the database if it already exists. If + set to True, the graph collections will be dropped and recreated. Note that + this operation is irreversible and will result in the loss of all data in + the graph. NOTE: If set to True, Collection Indexes will also be lost. + args: positional arguments for nx.Graph Additional arguments passed to nx.Graph. @@ -166,6 +172,7 @@ def __init__( write_async: bool = True, symmetrize_edges: bool = False, use_arango_views: bool = False, + overwrite_graph: bool = False, *args: Any, **kwargs: Any, ): @@ -183,6 +190,7 @@ def __init__( write_async, symmetrize_edges, use_arango_views, + overwrite_graph, *args, **kwargs, ) @@ -215,6 +223,8 @@ def __init__( else: nx.convert.to_networkx_graph(incoming_graph_data, create_using=self) + self._loaded_incoming_graph_data = True + ####################### # Init helper methods # ####################### diff --git a/nx_arangodb/exceptions.py b/nx_arangodb/exceptions.py index 35e538e..6df5be7 100644 --- a/nx_arangodb/exceptions.py +++ b/nx_arangodb/exceptions.py @@ -14,6 +14,10 @@ class GraphNameNotSet(NetworkXArangoDBException): pass +class GraphNotEmpty(NetworkXArangoDBException): + pass + + class InvalidTraversalDirection(NetworkXArangoDBException): pass diff --git a/pyproject.toml b/pyproject.toml index 7d181d3..b804121 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,9 +29,9 @@ classifiers = [ ] dependencies = [ "networkx>=3.0,<=3.3", - "phenolrs", - "python-arango", - "adbnx-adapter" + "phenolrs~=0.5", + "python-arango~=8.1", + "adbnx-adapter~=5.0.5" ] [project.optional-dependencies] diff --git a/tests/test.py b/tests/test.py index 6a19143..c454554 100644 --- a/tests/test.py +++ b/tests/test.py @@ -3,7 +3,7 @@ import networkx as nx import pytest -from arango import DocumentDeleteError +from arango.exceptions import DocumentInsertError from phenolrs.networkx.typings import ( DiGraphAdjDict, GraphAdjDict, @@ -17,7 +17,7 @@ from .conftest import create_grid_graph, create_line_graph, db, run_gpu_tests -G_NX = nx.karate_club_graph() +G_NX: nx.Graph = nx.karate_club_graph() G_NX_digraph = nx.DiGraph(G_NX) G_NX_multigraph = nx.MultiGraph(G_NX) G_NX_multidigraph = nx.MultiDiGraph(G_NX) @@ -108,6 +108,87 @@ def test_load_graph_from_nxadb(): db.delete_graph(graph_name, drop_collections=True) +def test_load_graph_from_nxadb_as_smart_graph(): + graph_name = "SmartKarateGraph" + + db.delete_graph(graph_name, drop_collections=True, ignore_missing=True) + db.create_graph( + graph_name, + smart=True, + smart_field="club", + edge_definitions=[ + { + "edge_collection": "smart_person_to_smart_person", + "from_vertex_collections": ["smart_person"], + "to_vertex_collections": ["smart_person"], + } + ], + ) + + # Small preprocessing to remove whitespaces from club names, + # as smart graphs do not allow whitespaces in smart fields + G_NX_copy = G_NX.copy() + for _, node in G_NX_copy.nodes(data=True): + node["club"] = node["club"].replace(" ", "") + + G = nxadb.Graph( + name=graph_name, + incoming_graph_data=G_NX_copy, + write_async=False, + ) + + assert db.has_graph(graph_name) + assert db.has_collection("smart_person") + assert db.has_collection("smart_person_to_smart_person") + assert db.collection("smart_person").count() == len(G_NX_copy.nodes) + assert db.collection("smart_person_to_smart_person").count() == len(G_NX_copy.edges) + + assert db.has_document("smart_person/Mr.Hi:0") + + with pytest.raises(DocumentInsertError): + G.add_node(35, club="Officer") + + with pytest.raises(DocumentInsertError): + G.add_node("35", club="Officer") + + with pytest.raises(DocumentInsertError): + G.add_node("smart_person/35", club="Officer") + + with pytest.raises(DocumentInsertError): + G.add_node("smart_person/Officer:35", club="officer") + + with pytest.raises(DocumentInsertError): + G.add_node("smart_person/Officer", club="Officer") + + assert G.nodes["Mr.Hi:0"]["club"] == "Mr.Hi" + G.add_node("Officer:35", club="Officer") + assert G.nodes["smart_person/Officer:35"]["club"] == "Officer" + + assert G["Mr.Hi:0"]["Mr.Hi:1"]["weight"] == 4 + G.add_edge("Mr.Hi:0", "Officer:35", weight=5) + assert G["Mr.Hi:0"]["Officer:35"]["weight"] == 5 + + G.add_nodes_from( + [("Officer:36", {"club": "Officer"}), ("Mr.Hi:37", {"club": "Mr.Hi"})] + ) + assert G.has_node("Officer:36") + assert G.has_node("Mr.Hi:37") + + assert db.collection("smart_person").properties()["smart"] + + G = nxadb.Graph( + name=graph_name, + incoming_graph_data=G_NX_copy, + write_async=False, + overwrite_graph=True, + ) + + assert db.collection("smart_person").properties()["smart"] + assert G.nodes["Mr.Hi:0"]["club"] == "Mr.Hi" + + db.delete_graph(graph_name, drop_collections=True) + + def test_load_graph_from_nxadb_w_specific_edge_attribute(): graph_name = "KarateGraph"