Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

smart graph support #61

Merged
merged 26 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions nx_arangodb/classes/dict/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def update(self, attrs: Any) -> None:
if not attrs:
return

self.data.update(build_node_attr_dict_data(self, attrs))
node_attr_dict_data = build_node_attr_dict_data(self, attrs)
self.data.update(node_attr_dict_data)

if not self.node_id:
logger.debug("Node ID not set, skipping NodeAttrDict(?).update()")
Expand Down Expand Up @@ -275,10 +276,12 @@ def __init__(
self.FETCHED_ALL_DATA = False
self.FETCHED_ALL_IDS = False

def _create_node_attr_dict(self, vertex: dict[str, Any]) -> NodeAttrDict:
def _create_node_attr_dict(
self, node_id: str, node_data: dict[str, Any]
) -> NodeAttrDict:
node_attr_dict = self.node_attr_dict_factory()
node_attr_dict.node_id = vertex["_id"]
node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, vertex)
node_attr_dict.node_id = node_id
node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data)

return node_attr_dict

Expand Down Expand Up @@ -322,7 +325,7 @@ def __getitem__(self, key: str) -> NodeAttrDict:
raise KeyError(key)

if vertex_db := vertex_get(self.graph, node_id):
node_attr_dict = self._create_node_attr_dict(vertex_db)
node_attr_dict = self._create_node_attr_dict(vertex_db["_id"], vertex_db)
aMahanna marked this conversation as resolved.
Show resolved Hide resolved
self.data[node_id] = node_attr_dict

return node_attr_dict
Expand All @@ -331,18 +334,16 @@ def __getitem__(self, key: str) -> NodeAttrDict:

@key_is_string
def __setitem__(self, key: str, value: NodeAttrDict) -> None:
"""G._node['node/1'] = {'foo': 'bar'}

Not to be confused with:
- G.add_node('node/1', foo='bar')
"""
"""G._node['node/1'] = {'foo': 'bar'}"""
assert isinstance(value, NodeAttrDict)

node_type, node_id = get_node_type_and_id(key, self.default_node_type)

result = doc_insert(self.db, node_type, node_id, value.data)

node_attr_dict = self._create_node_attr_dict(result)
node_attr_dict = self._create_node_attr_dict(
result["_id"], {**value.data, **result}
)

self.data[node_id] = node_attr_dict

Expand Down Expand Up @@ -405,10 +406,7 @@ def copy(self) -> Any:
@keys_are_strings
def __update_local_nodes(self, nodes: Any) -> None:
for node_id, node_data in nodes.items():
node_attr_dict = self.node_attr_dict_factory()
node_attr_dict.node_id = node_id
node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, node_data)

node_attr_dict = self._create_node_attr_dict(node_id, node_data)
self.data[node_id] = node_attr_dict

@keys_are_strings
Expand Down Expand Up @@ -478,7 +476,7 @@ def _fetch_all(self):

for node_id, node_data in node_dict.items():
del node_data["_rev"] # TODO: Optimize away via phenolrs
node_attr_dict = self._create_node_attr_dict(node_data)
node_attr_dict = self._create_node_attr_dict(node_data["_id"], node_data)
self.data[node_id] = node_attr_dict

self.FETCHED_ALL_DATA = True
Expand Down
70 changes: 64 additions & 6 deletions nx_arangodb/classes/digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ class DiGraph(Graph, nx.DiGraph):
whenever possible. NOTE: This feature is experimental and may not work
as expected.

overwrite_graph : bool (optional, default: False)
Whether to truncate the graph collections when the graph is loaded from
the database. If set to True, the graph collections will be truncated
before loading the graph data. NOTE: This parameter only applies if the
graph already exists in the database.

args: positional arguments for nx.Graph
Additional arguments passed to nx.Graph.

Expand Down Expand Up @@ -154,6 +160,7 @@ def __init__(
write_async: bool = True,
symmetrize_edges: bool = False,
use_arango_views: bool = False,
overwrite_graph: bool = False,
*args: Any,
**kwargs: Any,
):
Expand All @@ -171,13 +178,15 @@ def __init__(
write_async,
symmetrize_edges,
use_arango_views,
overwrite_graph,
*args,
**kwargs,
)

if self.graph_exists_in_db:
self.clear_edges = self.clear_edges_override
self.add_node = self.add_node_override
self.add_nodes_from = self.add_nodes_from_override
self.remove_node = self.remove_node_override
self.reverse = self.reverse_override

Expand All @@ -194,6 +203,7 @@ def __init__(
and not self._loaded_incoming_graph_data
):
nx.convert.to_networkx_graph(incoming_graph_data, create_using=self)
self._loaded_incoming_graph_data = True

#######################
# nx.DiGraph Overrides #
Expand Down Expand Up @@ -225,9 +235,10 @@ def clear_edges_override(self):
super().clear_edges()

def add_node_override(self, node_for_adding, **attr):
if node_for_adding is None:
raise ValueError("None cannot be a node")

if node_for_adding not in self._succ:
if node_for_adding is None:
raise ValueError("None cannot be a node")

self._succ[node_for_adding] = self.adjlist_inner_dict_factory()
self._pred[node_for_adding] = self.adjlist_inner_dict_factory()
Expand All @@ -241,12 +252,15 @@ def add_node_override(self, node_for_adding, **attr):
# attr_dict.update(attr)

# New:
self._node[node_for_adding] = self.node_attr_dict_factory()
self._node[node_for_adding].update(attr)
node_attr_dict = self.node_attr_dict_factory()
node_attr_dict.data = attr
self._node[node_for_adding] = node_attr_dict

# Reason:
# Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set
# i.e trying to update a node's attributes before we know _which_ node it is
# We can optimize the process of adding a node by avoiding the creation
# of an empty node_attr_dict that is then updated with the attributes.
# Instead, we create a new node_attr_dict object and set its attributes
# directly. This makes only 1 network call to the database instead of 2.

###########################

Expand All @@ -255,6 +269,50 @@ def add_node_override(self, node_for_adding, **attr):

nx._clear_cache(self)

def add_nodes_from_override(self, nodes_for_adding, **attr):
for n in nodes_for_adding:
aMahanna marked this conversation as resolved.
Show resolved Hide resolved
try:
newnode = n not in self._node
newdict = attr
except TypeError:
n, ndict = n
newnode = n not in self._node
newdict = attr.copy()
newdict.update(ndict)
if newnode:
if n is None:
raise ValueError("None cannot be a node")
self._succ[n] = self.adjlist_inner_dict_factory()
self._pred[n] = self.adjlist_inner_dict_factory()

######################
# NOTE: monkey patch #
######################

# Old:
# self._node[n] = self.node_attr_dict_factory()
#
# self._node[n].update(newdict)

# New:
node_attr_dict = self.node_attr_dict_factory()
node_attr_dict.data = newdict
self._node[n] = node_attr_dict

else:

aMahanna marked this conversation as resolved.
Show resolved Hide resolved
self._node[n].update(newdict)

# Reason:
# We can optimize the process of adding a node by avoiding the creation
# of an empty node_attr_dict that is then updated with the attributes.
# Instead, we create a new node_attr_dict object and set its attributes
# directly. This makes only 1 network call to the database instead of 2.

###########################

nx._clear_cache(self)

def remove_node_override(self, n):
if isinstance(n, (str, int)):
n = get_node_id(str(n), self.default_node_type)
Expand Down
27 changes: 14 additions & 13 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,17 @@ def to_dict(self):
return cls


def cast_to_string(value: Any) -> str:
    """Return ``value`` as a string.

    Strings are handed back unchanged, while ints and floats are converted
    with ``str``. Any other type is rejected.

    Raises:
        TypeError: if ``value`` is not a ``str``, ``int`` or ``float``.
    """
    # Guard clause: reject everything outside the supported types up front.
    if not isinstance(value, (str, int, float)):
        raise TypeError(f"{value} cannot be casted to string.")

    return value if isinstance(value, str) else str(value)


def key_is_string(func: Callable[..., Any]) -> Any:
"""Decorator to check if the key is a string.
Will attempt to cast the key to a string if it is not.
Expand All @@ -208,12 +219,7 @@ def wrapper(self: Any, key: Any, *args: Any, **kwargs: Any) -> Any:
if key is None:
raise ValueError("Key cannot be None.")

if not isinstance(key, str):
if not isinstance(key, (int, float)):
raise TypeError(f"{key} cannot be casted to string.")

key = str(key)

key = cast_to_string(key)
return func(self, key, *args, **kwargs)

return wrapper
Expand Down Expand Up @@ -270,12 +276,7 @@ def wrapper(self: Any, data: Any, *args: Any, **kwargs: Any) -> Any:
raise TypeError(f"Decorator found unsupported type: {type(data)}.")

for key, value in items:
if not isinstance(key, str):
if not isinstance(key, (int, float)):
raise TypeError(f"{key} cannot be casted to string.")

key = str(key)

key = cast_to_string(key)
data_dict[key] = value

return func(self, data_dict, *args, **kwargs)
Expand Down Expand Up @@ -655,7 +656,7 @@ def doc_insert(
data: dict[str, Any] = {},
**kwargs: Any,
) -> dict[str, Any]:
"""Inserts a document into a collection."""
"""Inserts a document into a collection. Returns document metadata."""
result: dict[str, Any] = db.insert_document(
collection, {**data, "_id": id}, overwrite=True, **kwargs
)
Expand Down
Loading
Loading