Skip to content

Commit

Permalink
smart graph support | initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed Sep 23, 2024
1 parent 7999151 commit 152fe62
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 100 deletions.
14 changes: 14 additions & 0 deletions nx_arangodb/classes/digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(
and not self._loaded_incoming_graph_data
):
nx.convert.to_networkx_graph(incoming_graph_data, create_using=self)
self._loaded_incoming_graph_data = True

#######################
# nx.DiGraph Overides #
Expand Down Expand Up @@ -241,12 +242,25 @@ def add_node_override(self, node_for_adding, **attr):
# attr_dict.update(attr)

# New:

node_attr_dict = self.node_attr_dict_factory()

if self.is_smart:
if self.smart_field not in attr:
m = f"Node {node_for_adding} missing smart field '{self.smart_field}'" # noqa: E501
raise KeyError(m)

node_attr_dict.data[self.smart_field] = attr[self.smart_field]

self._node[node_for_adding] = self.node_attr_dict_factory()
self._node[node_for_adding].update(attr)

# Reason:
# Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set
# i.e trying to update a node's attributes before we know _which_ node it is
# Furthermore, support for ArangoDB Smart Graphs requires the smart field
# to be set before adding the node to the graph. This is because the smart
# field is used to generate the node's key.

###########################

Expand Down
245 changes: 145 additions & 100 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import Any, Callable, ClassVar

import networkx as nx
from adbnx_adapter import ADBNX_Adapter
from adbnx_adapter import ADBNX_Adapter, ADBNX_Controller
from adbnx_adapter.typings import NxData, NxId
from arango import ArangoClient
from arango.cursor import Cursor
from arango.database import StandardDatabase
Expand Down Expand Up @@ -186,19 +187,18 @@ def __init__(
write_async: bool = True,
symmetrize_edges: bool = False,
use_arango_views: bool = False,
overwrite_graph: bool = False,
*args: Any,
**kwargs: Any,
):
self.__db = None
self.__name = None
self.__use_arango_views = use_arango_views
self.__graph_exists_in_db = False

self.__set_db(db)
if self.__db is not None:
self.__set_graph_name(name)

self.__set_edge_collections_attributes(edge_collections_attributes)
if all([self.__db, name]):
self.__set_graph(name, default_node_type, edge_type_func)
self.__set_edge_collections_attributes(edge_collections_attributes)

# NOTE: Need to revisit these...
# self.maintain_node_dict_cache = False
Expand All @@ -219,96 +219,25 @@ def __init__(
# raise ValueError(m)

self._loaded_incoming_graph_data = False

if self.__graph_exists_in_db:
if incoming_graph_data is not None:
m = "Cannot pass both **incoming_graph_data** and **name** yet if the already graph exists" # noqa: E501
raise NotImplementedError(m)

if edge_type_func is not None:
m = "Cannot pass **edge_type_func** if the graph already exists"
raise NotImplementedError(m)

self.adb_graph = self.db.graph(self.__name)
vertex_collections = self.adb_graph.vertex_collections()
edge_definitions = self.adb_graph.edge_definitions()

if default_node_type is None:
default_node_type = list(vertex_collections)[0]
logger.info(f"Default node type set to '{default_node_type}'")
elif default_node_type not in vertex_collections:
m = f"Default node type '{default_node_type}' not found in graph '{name}'" # noqa: E501
raise InvalidDefaultNodeType(m)

node_types_to_edge_type_map: dict[tuple[str, str], str] = {}
for e_d in edge_definitions:
for f in e_d["from_vertex_collections"]:
for t in e_d["to_vertex_collections"]:
if (f, t) in node_types_to_edge_type_map:
# TODO: Should we log a warning at least?
continue

node_types_to_edge_type_map[(f, t)] = e_d["edge_collection"]

def edge_type_func(u: str, v: str) -> str:
try:
return node_types_to_edge_type_map[(u, v)]
except KeyError:
m = f"Edge type ambiguity between '{u}' and '{v}'"
raise EdgeTypeAmbiguity(m)

self.edge_type_func = edge_type_func
self.default_node_type = default_node_type

if self.graph_exists_in_db:
self._set_factory_methods()
self.__set_arangodb_backend_config(read_parallelism, read_batch_size)

elif self.__name:
if overwrite_graph:
logger.info("Truncating graph collections...")

prefix = f"{name}_" if name else ""
if default_node_type is None:
default_node_type = f"{prefix}node"
if edge_type_func is None:
edge_type_func = lambda u, v: f"{u}_to_{v}" # noqa: E731
for col in self.adb_graph.vertex_collections():
self.db.collection(col).truncate()

self.edge_type_func = edge_type_func
self.default_node_type = default_node_type

# TODO: Parameterize the edge definitions
# How can we work with a heterogenous **incoming_graph_data**?
default_edge_type = edge_type_func(default_node_type, default_node_type)
edge_definitions = [
{
"edge_collection": default_edge_type,
"from_vertex_collections": [default_node_type],
"to_vertex_collections": [default_node_type],
}
]
for col in self.adb_graph.edge_definitions():
self.db.collection(col["edge_collection"]).truncate()

if isinstance(incoming_graph_data, nx.Graph):
self.adb_graph = ADBNX_Adapter(self.db).networkx_to_arangodb(
self.__name,
incoming_graph_data,
edge_definitions=edge_definitions,
batch_size=write_batch_size,
use_async=write_async,
)

self._load_nx_graph(incoming_graph_data, write_batch_size, write_async)
self._loaded_incoming_graph_data = True

else:
self.adb_graph = self.db.create_graph(
self.__name,
edge_definitions=edge_definitions,
)

self._set_factory_methods()
self.__set_arangodb_backend_config(read_parallelism, read_batch_size)
logger.info(f"Graph '{name}' created.")
self.__graph_exists_in_db = True

if self.__name is not None:
kwargs["name"] = self.__name
if name is not None:
kwargs["name"] = name

super().__init__(*args, **kwargs)

Expand All @@ -333,6 +262,7 @@ def edge_type_func(u: str, v: str) -> str:
and not self._loaded_incoming_graph_data
):
nx.convert.to_networkx_graph(incoming_graph_data, create_using=self)
self._loaded_incoming_graph_data = True

#######################
# Init helper methods #
Expand Down Expand Up @@ -423,23 +353,118 @@ def __set_db(self, db: Any = None) -> None:
self._db_name, self._username, self._password, verify=True
)

def __set_graph_name(self, name: Any = None) -> None:
if self.__db is None:
m = "Cannot set graph name without setting the database first"
raise DatabaseNotSet(m)

if not name:
self.__graph_exists_in_db = False
logger.warning(f"**name** not set for {self.__class__.__name__}")
return

def __set_graph(
self,
name: Any,
default_node_type: str | None = None,
edge_type_func: Callable[[str, str], str] | None = None,
) -> None:
if not isinstance(name, str):
raise TypeError("**name** must be a string")

if self.db.has_graph(name):
logger.info(f"Graph '{name}' exists.")

if edge_type_func is not None:
m = "Cannot pass **edge_type_func** if the graph already exists"
raise NotImplementedError(m)

self.adb_graph = self.db.graph(name)
vertex_collections = self.adb_graph.vertex_collections()
edge_definitions = self.adb_graph.edge_definitions()

if default_node_type is None:
default_node_type = list(vertex_collections)[0]
logger.info(f"Default node type set to '{default_node_type}'")

elif default_node_type not in vertex_collections:
m = f"Default node type '{default_node_type}' not found in graph '{name}'" # noqa: E501
raise InvalidDefaultNodeType(m)

node_types_to_edge_type_map: dict[tuple[str, str], str] = {}
for e_d in edge_definitions:
for f in e_d["from_vertex_collections"]:
for t in e_d["to_vertex_collections"]:
if (f, t) in node_types_to_edge_type_map:
# TODO: Should we log a warning at least?
continue

node_types_to_edge_type_map[(f, t)] = e_d["edge_collection"]

def edge_type_func(u: str, v: str) -> str:
try:
return node_types_to_edge_type_map[(u, v)]
except KeyError:
m = f"Edge type ambiguity between '{u}' and '{v}'"
raise EdgeTypeAmbiguity(m)

else:
prefix = f"{name}_" if name else ""

if default_node_type is None:
default_node_type = f"{prefix}node"

if edge_type_func is None:
edge_type_func = lambda u, v: f"{u}_to_{v}" # noqa: E731

# TODO: Parameterize the edge definitions
# How can we work with a heterogenous **incoming_graph_data**?
default_edge_type = edge_type_func(default_node_type, default_node_type)
edge_definitions = [
{
"edge_collection": default_edge_type,
"from_vertex_collections": [default_node_type],
"to_vertex_collections": [default_node_type],
}
]

# Create a general graph if it doesn't exist
self.adb_graph = self.db.create_graph(
name=name,
edge_definitions=edge_definitions,
)

logger.info(f"Graph '{name}' created.")

self.__name = name
self.__graph_exists_in_db = self.db.has_graph(name)
self.__graph_exists_in_db = True
self.edge_type_func = edge_type_func
self.default_node_type = default_node_type

properties = self.adb_graph.properties()
self.__is_smart: bool = properties.get("smart", False)
self.__smart_field: str | None = properties.get("smart_field")

def _load_nx_graph(
self, nx_graph: nx.Graph, write_batch_size: int, write_async: bool
) -> None:
controller = ADBNX_Controller

if all([self.is_smart, self.smart_field]):
smart_field = self.__smart_field

logger.info(f"Graph '{name}' exists: {self.__graph_exists_in_db}")
class SmartController(ADBNX_Controller):
def _keyify_networkx_node(
self, i: int, nx_node_id: NxId, nx_node: NxData, col: str
) -> str:
if smart_field not in nx_node:
m = f"Node {nx_node_id} missing smart field '{smart_field}'" # noqa: E501
raise KeyError(m)

return f"{nx_node[smart_field]}:{str(i)}"

def _prepare_networkx_edge(self, nx_edge: NxData, col: str) -> None:
del nx_edge["_key"]

controller = SmartController
logger.info(f"Using smart field '{smart_field}' for node keys")

ADBNX_Adapter(self.db, controller()).networkx_to_arangodb(
self.adb_graph.name,
nx_graph,
batch_size=write_batch_size,
use_async=write_async,
)

###########
# Getters #
Expand Down Expand Up @@ -479,6 +504,14 @@ def graph_exists_in_db(self) -> bool:
def edge_attributes(self) -> set[str]:
return self._edge_collections_attributes

@property
def is_smart(self) -> bool:
return self.__is_smart

@property
def smart_field(self) -> str | None:
return self.__smart_field

###########
# Setters #
###########
Expand Down Expand Up @@ -645,12 +678,24 @@ def add_node_override(self, node_for_adding, **attr):
# attr_dict.update(attr)

# New:
self._node[node_for_adding] = self.node_attr_dict_factory()
node_attr_dict = self.node_attr_dict_factory()

if self.is_smart:
if self.smart_field not in attr:
m = f"Node {node_for_adding} missing smart field '{self.smart_field}'" # noqa: E501
raise KeyError(m)

node_attr_dict.data[self.smart_field] = attr[self.smart_field]

self._node[node_for_adding] = node_attr_dict
self._node[node_for_adding].update(attr)

# Reason:
# Invoking `update` on the `attr_dict` without `attr_dict.node_id` being set
# i.e trying to update a node's attributes before we know _which_ node it is
# Furthermore, support for ArangoDB Smart Graphs requires the smart field
# to be set before adding the node to the graph. This is because the smart
# field is used to generate the node's key.

###########################

Expand Down
2 changes: 2 additions & 0 deletions nx_arangodb/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def __init__(
else:
nx.convert.to_networkx_graph(incoming_graph_data, create_using=self)

self._loaded_incoming_graph_data = True

#######################
# Init helper methods #
#######################
Expand Down

0 comments on commit 152fe62

Please sign in to comment.