Skip to content

Commit

Permalink
checkpoint (CI is failing)
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed May 4, 2024
1 parent 45a1e20 commit 5ed8afa
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 44 deletions.
87 changes: 65 additions & 22 deletions nx_arangodb/classes/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable

from arango.database import StandardDatabase
from arango.exceptions import DocumentInsertError
from arango.graph import Graph

from .function import (
Expand All @@ -14,7 +15,9 @@
aql_doc_get_keys,
aql_doc_get_length,
aql_doc_has_key,
aql_edge_exists,
aql_edge_get,
aql_edge_id,
aql_single,
create_collection,
doc_delete,
Expand Down Expand Up @@ -84,7 +87,7 @@ class GraphDict(UserDict):
:type graph_name: str
"""

COLLECTION_NAME = "NXADB_GRAPH_ATTRIBUTES"
COLLECTION_NAME = "nxadb_graphs"

def __init__(self, db: StandardDatabase, graph_name: str, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -134,8 +137,8 @@ def __setitem__(self, key: str, value: Any):
@key_is_not_reserved
def __delitem__(self, key):
"""del G.graph['foo']"""
del self.data[key]
doc_update(self.db, self.graph_id, {key: None})
self.data.pop(key, None)

@keys_are_strings
@keys_are_not_reserved
Expand Down Expand Up @@ -196,8 +199,8 @@ def __getitem__(self, key: str) -> NodeAttrDict:
"""G._node['node/1']"""
node_id = get_node_id(key, self.default_node_type)

if node_id in self.data:
return self.data[node_id]
if value := self.data.get(node_id):
return value

if value := self.graph.vertex(node_id):
node_attr_dict: NodeAttrDict = self.node_attr_dict_factory()
Expand Down Expand Up @@ -234,8 +237,8 @@ def __delitem__(self, key: Any) -> None:
"""del g._node['node/1']"""
node_id = get_node_id(key, self.default_node_type)

del self.data[node_id]
doc_delete(self.db, node_id)
self.data.pop(node_id, None)

def __len__(self) -> int:
"""len(g._node)"""
Expand Down Expand Up @@ -378,8 +381,8 @@ def __setitem__(self, key: str, value: Any):
@key_is_not_reserved
def __delitem__(self, key: str):
"""del G._node['node/1']['foo']"""
del self.data[key]
doc_update(self.db, self.node_id, {key: None})
self.data.pop(key, None)

def __iter__(self) -> Iterator[str]:
"""for key in G._node['node/1']"""
Expand Down Expand Up @@ -485,8 +488,8 @@ def __getitem__(self, key) -> AdjListInnerDict:
"""G.adj["node/1"]"""
node_type, node_id = get_node_type_and_id(key, self.default_node_type)

if node_id in self.data:
return self.data[node_id]
if value := self.data.get(node_id):
return value

if self.graph.has_vertex(node_id):
adjlist_inner_dict: AdjListInnerDict = self.adjlist_inner_dict_factory()
Expand Down Expand Up @@ -520,12 +523,12 @@ def __setitem__(self, src_key: str, adjlist_inner_dict: AdjListInnerDict):
dst_key, self.default_node_type
)

edge_type = edge_dict.get("_edge_type") # pop?
edge_type = edge_dict.get("_edge_type")
if edge_type is None:
edge_type = self.edge_type_func(src_node_type, dst_node_type)

results[dst_key] = self.graph.link(
edge_type, src_node_id, dst_node_id, edge_dict, silent=True
edge_type, src_node_id, dst_node_id, edge_dict
)

adjlist_inner_dict.src_node_id = src_node_id
Expand All @@ -539,7 +542,24 @@ def __delitem__(self, key: Any) -> None:
"""
del G._adj['node/1']
"""
raise NotImplementedError("AdjListOuterDict.__delitem__()")
node_id = get_node_id(key, self.default_node_type)

if not self.graph.has_vertex(node_id):
return

remove_statements = "\n".join(
f"REMOVE e IN `{edge_def['edge_collection']}` OPTIONS {{ignoreErrors: true}}"
for edge_def in self.graph.edge_definitions()
)

query = f"""
FOR v, e IN 1..1 OUTBOUND @src_node_id GRAPH @graph_name
{remove_statements}
"""

bind_vars = {"src_node_id": node_id, "graph_name": self.graph.name}

aql(self.db, query, bind_vars)

def __len__(self) -> int:
"""len(g._adj)"""
Expand Down Expand Up @@ -584,7 +604,7 @@ def update(self, edges: dict[str, dict[str, dict[str, Any]]]):
for dst_key, edge_dict in dst_dict.items():
dst_node_type, dst_node_id = get_node_type_and_id(dst_key)

edge_type = edge_dict.get("_edge_type") # pop?
edge_type = edge_dict.get("_edge_type")
if edge_type is None:
edge_type = self.edge_type_func(src_node_type, dst_node_type)

Expand Down Expand Up @@ -693,17 +713,14 @@ def __contains__(self, key) -> bool:
if dst_node_id in self.data:
return True

return aql_edge_get(
return aql_edge_exists(
self.db,
self.src_node_id,
dst_node_id,
self.graph.name,
direction="OUTBOUND",
return_bool=True,
)

# CHECKPOINT...

@key_is_string
def __getitem__(self, key) -> EdgeAttrDict:
"""g._adj['node/1']['node/2']"""
Expand All @@ -718,7 +735,6 @@ def __getitem__(self, key) -> EdgeAttrDict:
dst_node_id,
self.graph.name,
direction="OUTBOUND",
return_bool=False,
)

if not edge:
Expand All @@ -739,11 +755,24 @@ def __setitem__(self, key: str, value: dict | EdgeAttrDict):

dst_node_type, dst_node_id = get_node_type_and_id(key, self.default_node_type)

edge_type = value.data.get("_edge_type") # pop?
edge_type = value.data.get("_edge_type")
if edge_type is None:
edge_type = self.edge_type_func(self.src_node_type, dst_node_type)

data = value.data

if edge_id := value.edge_id:
self.graph.delete_edge(edge_id)

elif edge_id := aql_edge_id(
self.db,
self.src_node_id,
dst_node_id,
self.graph.name,
direction="OUTBOUND",
):
self.graph.delete_edge(edge_id)

edge = self.graph.link(edge_type, self.src_node_id, dst_node_id, data)

edge_attr_dict = self.edge_attr_dict_factory()
Expand All @@ -755,7 +784,21 @@ def __setitem__(self, key: str, value: dict | EdgeAttrDict):
@key_is_string
def __delitem__(self, key: Any) -> None:
"""del g._adj['node/1']['node/2']"""
raise NotImplementedError("AdjListInnerDict.__delitem__()")
dst_node_id = get_node_id(key, self.default_node_type)

edge_id = aql_edge_id(
self.db,
self.src_node_id,
dst_node_id,
self.graph.name,
direction="OUTBOUND",
)

if not edge_id:
return

self.graph.delete_edge(edge_id)
self.data.pop(dst_node_id, None)

def __len__(self) -> int:
"""len(g._adj['node/1'])"""
Expand Down Expand Up @@ -895,8 +938,8 @@ def __setitem__(self, key: str, value: Any):
@key_is_not_reserved
def __delitem__(self, key: str):
"""del G._adj['node/1']['node/2']['foo']"""
del self.data[key]
doc_update(self.db, self.node_id, {key: None})
doc_update(self.db, self.edge_id, {key: None})
self.data.pop(key, None)

def __iter__(self) -> Iterator[str]:
"""for key in G._adj['node/1']['node/2']"""
Expand All @@ -922,7 +965,7 @@ def values(self, cache: bool = True):

def items(self, cache: bool = True):
"""G._adj['node/1']['node/'2].items()"""
doc = self.db.document(self.node_id)
doc = self.db.document(self.edge_id)

if cache:
self.data = doc
Expand Down
17 changes: 17 additions & 0 deletions nx_arangodb/classes/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
class NetworkXArangoDBException(Exception):
pass


EDGE_ALREADY_EXISTS_ERROR_CODE = 1210


class EdgeAlreadyExists(NetworkXArangoDBException):
"""Raised when trying to add an edge that already exists in the graph."""

pass


class AQLMultipleResultsFound(NetworkXArangoDBException):
"""Raised when multiple results are returned from a query that was expected to return a single result."""

pass
Loading

0 comments on commit 5ed8afa

Please sign in to comment.