From c949e97c206d2df557f81da847ddf6a5dc633f45 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Thu, 7 Mar 2024 17:38:18 -0500 Subject: [PATCH] Store position attribute as Array type The list of position attributes has now been replaced with a single attribute, which can be an Array. Postgre will store the array directly, SQLite will create one column for each element. This commit also tightens the requirements for optional arguments passed to the graph databases and simplifies the meta-data checks. There is now also documentation on the allowed read/write modes and checks for the passed arguments. --- .../graphs/pgsql_graph_database.py | 13 +- .../persistence/graphs/sql_graph_database.py | 288 +++++++++++------- .../graphs/sqlite_graph_database.py | 115 ++++++- tests/conftest.py | 4 +- tests/test_graph.py | 193 +++++++----- 5 files changed, 418 insertions(+), 195 deletions(-) diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/funlib/persistence/graphs/pgsql_graph_database.py index ec40146..6c701c5 100644 --- a/funlib/persistence/graphs/pgsql_graph_database.py +++ b/funlib/persistence/graphs/pgsql_graph_database.py @@ -14,7 +14,7 @@ class PgSQLGraphDatabase(SQLGraphDataBase): def __init__( self, - position_attributes: list[str], + position_attribute: str, db_name: str, db_host: str = "localhost", db_user: Optional[str] = None, @@ -62,8 +62,8 @@ def __init__( self.cur = self.connection.cursor() super().__init__( - position_attributes, mode=mode, + position_attribute=position_attribute, directed=directed, total_roi=total_roi, nodes_table=nodes_table, @@ -86,10 +86,8 @@ def _drop_tables(self) -> None: self._commit() def _create_tables(self) -> None: - columns = self.position_attributes + list(self.node_attrs.keys()) - types = [self.__sql_type(float) + " NOT NULL"] * len( - self.position_attributes - ) + list([self.__sql_type(t) for t in self.node_attrs.values()]) + columns = self.node_attrs.keys() + types = [self.__sql_type(t) for t in self.node_attrs.values()] column_types = [f"{c} {t}" for c, t in zip(columns, types)] self.__exec( f"CREATE TABLE IF NOT EXISTS " @@ -100,7 +98,7 @@ def _create_tables(self) -> None: ) self.__exec( f"CREATE INDEX IF NOT EXISTS pos_index ON " - f"{self.nodes_table_name}({','.join(self.position_attributes)})" + f"{self.nodes_table_name}({self.position_attribute})" ) columns = list(self.edge_attrs.keys()) @@ -139,6 +137,7 @@ def _read_metadata(self) -> Optional[dict[str, Any]]: return None def _select_query(self, query) -> Iterable[Any]: + print(query) self.__exec(query) return self.cur diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index 9627623..f3af09b 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -17,16 +17,21 @@ class SQLGraphDataBase(GraphDataBase): """Base class for SQL-based graph databases. - Nodes must have position attributes (set via argument - ``position_attributes``), which will be used for geometric slicing (see + Nodes must have a position attribute (set via argument + ``position_attribute``), which will be used for geometric slicing (see ``__getitem__`` and ``read_graph``). Arguments: - position_attributes (list of ``string``s): + mode (``string``): - The node attributes that contain position information. This will - be used for slicing subgraphs via ``__getitem__``. + Any of ``r`` (read-only), ``r+`` (read and allow modifications), + or ``w`` (create new database, overwrite if exists). + + position_attribute (``string``): + + The node attribute that contains position information. This will be + used for slicing subgraphs via ``__getitem__``. directed (``bool``): @@ -56,38 +61,95 @@ class SQLGraphDataBase(GraphDataBase): The custom attributes to store on each edge. """ + read_modes = ["r", "r+"] + write_modes = ["r+", "w"] + create_modes = ["w"] + valid_modes = ["r", "r+", "w"] + _node_attrs: Optional[dict[str, AttributeType]] = None _edge_attrs: Optional[dict[str, AttributeType]] = None def __init__( self, - position_attributes: list[str], mode: str = "r+", + position_attribute: Optional[str] = None, directed: Optional[bool] = None, total_roi: Optional[Roi] = None, - nodes_table: str = "nodes", - edges_table: str = "edges", + nodes_table: Optional[str] = None, + edges_table: Optional[str] = None, endpoint_names: Optional[list[str]] = None, node_attrs: Optional[dict[str, AttributeType]] = None, edge_attrs: Optional[dict[str, AttributeType]] = None, ): - self.position_attributes = position_attributes - self.ndim = len(self.position_attributes) + assert mode in self.valid_modes, f"Mode '{mode}' not in allowed modes {self.valid_modes}" self.mode = mode - self.directed = directed - self.total_roi = total_roi - self.nodes_table_name = nodes_table - self.edges_table_name = edges_table - self.endpoint_names = ["u", "v"] if endpoint_names is None else endpoint_names - self._node_attrs = node_attrs - self._edge_attrs = edge_attrs + if mode in self.read_modes: + + self.position_attribute = position_attribute + self.directed = directed + self.total_roi = total_roi + self.nodes_table_name = nodes_table + self.edges_table_name = edges_table + self.endpoint_names = endpoint_names + self._node_attrs = node_attrs + self._edge_attrs = edge_attrs + self.ndims = None # to be read from metadata + + metadata = self._read_metadata() + self.__load_metadata(metadata) + + if mode in self.create_modes: + + # this is where we populate default values for the DB creation + + assert node_attrs is not None, ( + "For DB creation (mode 'w'), node_attrs is a required " + "argument and needs to contain at least the type definition " + "for the position attribute" + ) + + def get(value, default): + return value if value is not None else default + + self.position_attribute = get(position_attribute, "position") + + assert self.position_attribute in node_attrs, ( + "No type information for position attribute " + f"'{self.position_attribute}' in 'node_attrs'" + ) - if mode == "w": + position_type = node_attrs[self.position_attribute] + if isinstance(position_type, Array): + self.ndims = position_type.size + assert self.ndims > 1, ( + "Don't use Arrays of size 1 for the position, use the " + "scalar type directly instead (i.e., 'float' instead of " + "'Array(float, 1)'." + ) + # if ndims == 1, we know that we have a single scalar now + else: + self.ndims = 1 + + self.directed = get(directed, False) + self.total_roi = get( + total_roi, + Roi((None,) * self.ndims, (None,) * self.ndims)) + self.nodes_table_name = get(nodes_table, "nodes") + self.edges_table_name = get(edges_table, "edges") + self.endpoint_names = get(endpoint_names, ["u", "v"]) + self._node_attrs = node_attrs # no default, needs to be given + self._edge_attrs = get(edge_attrs, {}) + + # delete previous DB, if exists self._drop_tables() - self._create_tables() - self.__init_metadata() + # create new DB + self._create_tables() + + # store metadata + metadata = self.__create_metadata() + self._store_metadata(metadata) @abstractmethod def _drop_tables(self) -> None: @@ -123,6 +185,22 @@ def _update_query(self, query, commit=True) -> None: def _commit(self) -> None: pass + def _node_attrs_to_columns(self, attrs): + # default: each attribute maps to its own column + return attrs + + def _columns_to_node_attrs(self, columns, attrs): + # default: each column maps to one attribute + return columns + + def _edge_attrs_to_columns(self, attrs): + # default: each attribute maps to its own column + return attrs + + def _columns_to_edge_attrs(self, columns, attrs): + # default: each column maps to one attribute + return columns + def read_graph( self, roi: Optional[Roi] = None, @@ -230,9 +308,17 @@ def read_nodes( ) -> list[dict[str, Any]]: """Return a list of nodes within roi.""" + # attributes to read + read_attrs = list(self.node_attrs.keys()) if read_attrs is None else read_attrs + + # corresponding column naes + read_columns = ["id"] + self._node_attrs_to_columns(read_attrs) + read_attrs = ["id"] + read_attrs + read_attrs_query = ", ".join(read_columns) + logger.debug("Reading nodes in roi %s" % roi) select_statement = ( - f"SELECT * FROM {self.nodes_table_name} " + f"SELECT {read_attrs_query} FROM {self.nodes_table_name} " + (self.__roi_query(roi) if roi is not None else "") + ( f" {'WHERE' if roi is None else 'AND'} " @@ -242,27 +328,26 @@ def read_nodes( ) ) - read_attrs = ( - ["id"] - + self.position_attributes - + (list(self.node_attrs.keys()) if read_attrs is None else read_attrs) - ) attr_filter = attr_filter if attr_filter is not None else {} for k, v in attr_filter.items(): select_statement += f" AND {k}={self.__convert_to_sql(v)}" nodes = [ - { - key: val - for key, val in zip( - ["id"] + self.position_attributes + list(self.node_attrs.keys()), - values, - ) - if key in read_attrs and val is not None - } + self._columns_to_node_attrs( + { + key: val + for key, val in zip(read_columns, values) + }, + read_attrs + ) for values in self._select_query(select_statement) ] + for values in self._select_query(select_statement): + print(values) + print(self.node_attrs.keys()) + print(nodes) + return nodes def num_nodes(self, roi: Roi) -> int: @@ -348,8 +433,8 @@ def write_edges( if roi is None: roi = Roi( - (None,) * len(self.position_attributes), - (None,) * len(self.position_attributes), + (None,) * self.ndims, + (None,) * self.ndims, ) values = [] @@ -438,7 +523,7 @@ def write_nodes( logger.debug("Writing nodes in %s", roi) attrs = attributes if attributes is not None else list(self.node_attrs.keys()) - columns = ("id",) + tuple(self.position_attributes) + tuple(attrs) + columns = ("id",) + tuple(attrs) values = [] for node_id, data in nodes.items(): @@ -446,11 +531,9 @@ def write_nodes( pos = self.__get_node_pos(data) if roi is not None and not roi.contains(pos): continue - for i, position_attribute in enumerate(self.position_attributes): - data[position_attribute] = pos[i] values.append( [node_id] - + [data.get(attr, None) for attr in self.position_attributes + attrs] + + [data.get(attr, None) for attr in attrs] ) if len(values) == 0: @@ -498,83 +581,77 @@ def update_nodes( self._commit() - def __init_metadata(self): - metadata = self._read_metadata() - - if metadata: - self.__check_metadata(metadata) - else: - metadata = self.__create_metadata() - self._store_metadata(metadata) - def __create_metadata(self): """Sets the metadata in the meta collection to the provided values""" - if not self.directed: - # default is False - self.directed = self.directed if self.directed is not None else False - if not self.total_roi: - # default is an unbounded roi - self.total_roi = Roi( - (None,) * len(self.position_attributes), - (None,) * len(self.position_attributes), - ) - metadata = { + "position_attribute": self.position_attribute, "directed": self.directed, "total_roi_offset": self.total_roi.offset, "total_roi_shape": self.total_roi.shape, + "nodes_table_name": self.nodes_table_name, + "edges_table_name": self.edges_table_name, + "endpoint_names": self.endpoint_names, "node_attrs": {k: type_to_str(v) for k, v in self.node_attrs.items()}, "edge_attrs": {k: type_to_str(v) for k, v in self.edge_attrs.items()}, + "ndims": self.ndims, } return metadata - def __check_metadata(self, metadata): - """Checks if the provided metadata matches the existing - metadata in the meta collection""" - - if self.directed is not None and metadata["directed"] != self.directed: - raise ValueError( - ( - "Input parameter directed={} does not match" - "directed value {} already in stored metadata" - ).format(self.directed, metadata["directed"]) - ) - elif self.directed is None: - self.directed = metadata["directed"] - if self.total_roi is not None: - if self.total_roi.get_offset() != metadata["total_roi_offset"]: - raise ValueError( - ( - "Input total_roi offset {} does not match" - "total_roi offset {} already stored in metadata" - ).format(self.total_roi.get_offset(), metadata["total_roi_offset"]) - ) - if self.total_roi.get_shape() != metadata["total_roi_shape"]: - raise ValueError( - ( - "Input total_roi shape {} does not match" - "total_roi shape {} already stored in metadata" - ).format(self.total_roi.get_shape(), metadata["total_roi_shape"]) + def __load_metadata(self, metadata): + """Load the provided metadata into this object's attributes, check if + it is consistent with already populated fields.""" + + # simple attributes + for attr_name in [ + "position_attribute", + "directed", + "nodes_table_name", + "edges_table_name", + "endpoint_names", + "ndims"]: + + if getattr(self, attr_name) is None: + setattr(self, attr_name, metadata[attr_name]) + else: + value = getattr(self, attr_name) + assert value == metadata[attr_name], ( + f"Attribute {attr_name} is already set to {value} for this " + "object, but disagrees with the stored metadata value of " + f"{metadata[attr_name]}" ) + + # special attributes + + total_roi = Roi(metadata["total_roi_offset"], metadata["total_roi_shape"]) + if self.total_roi is None: + self.total_roi = total_roi else: - self.total_roi = Roi( - metadata["total_roi_offset"], metadata["total_roi_shape"] - ) - metadata["node_attrs"] = {k: eval(v) for k, v in metadata["node_attrs"].items()} - metadata["edge_attrs"] = {k: eval(v) for k, v in metadata["edge_attrs"].items()} - if self._node_attrs is not None: - assert self.node_attrs == metadata["node_attrs"], ( - self.node_attrs, - metadata["node_attrs"], + assert self.total_roi == total_roi, ( + f"Attribute total_roi is already set to {self.total_roi} for " + "this object, but disagrees with the stored metadata value of " + f"{total_roi}" ) + + node_attrs = {k: eval(v) for k, v in metadata["node_attrs"].items()} + edge_attrs = {k: eval(v) for k, v in metadata["edge_attrs"].items()} + if self._node_attrs is None: + self.node_attrs = node_attrs else: - self.node_attrs = metadata["node_attrs"] - if self._edge_attrs is not None: - assert self.edge_attrs == metadata["edge_attrs"] + assert self.node_attrs == node_attrs, ( + f"Attribute node_attrs is already set to {self.node_attrs} for " + "this object, but disagrees with the stored metadata value of " + f"{node_attrs}" + ) + if self._edge_attrs is None: + self.edge_attrs = edge_attrs else: - self.edge_attrs = metadata["edge_attrs"] + assert self.edge_attrs == edge_attrs, ( + f"Attribute edge_attrs is already set to {self.edge_attrs} for " + "this object, but disagrees with the stored metadata value of " + f"{edge_attrs}" + ) def __remove_keys(self, dictionary, keys): """Removes given keys from dictionary.""" @@ -582,9 +659,7 @@ def __remove_keys(self, dictionary, keys): return {k: v for k, v in dictionary.items() if k not in keys} def __get_node_pos(self, n: dict[str, Any]) -> Coordinate: - return Coordinate( - (n.get(pos_attr, None) for pos_attr in self.position_attributes) - ) + return Coordinate(n[self.position_attribute]) def __convert_to_sql(self, x: Any) -> str: if isinstance(x, str): @@ -605,15 +680,16 @@ def __attr_query(self, attrs: dict[str, Any]) -> str: def __roi_query(self, roi: Roi) -> str: query = "WHERE " - for dim, pos_attr in enumerate(self.position_attributes): + pos_attr = self.position_attribute + for dim in range(self.ndims): if dim > 0: query += " AND " if roi.begin[dim] is not None and roi.end[dim] is not None: - query += f"{pos_attr} BETWEEN {roi.begin[dim]} and {roi.end[dim]}" + query += f"{pos_attr}[{dim + 1}] BETWEEN {roi.begin[dim]} and {roi.end[dim]}" elif roi.begin[dim] is not None: - query += f"{pos_attr}>={roi.begin[dim]}" + query += f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" elif roi.begin[dim] is not None: - query += f"{pos_attr}<{roi.end[dim]}" + query += f"{pos_attr}[{dim + 1}]<{roi.end[dim]}" else: query = query[:-5] return query diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/funlib/persistence/graphs/sqlite_graph_database.py index 5d45b4d..4c3760e 100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/funlib/persistence/graphs/sqlite_graph_database.py @@ -1,10 +1,12 @@ from .sql_graph_database import SQLGraphDataBase, AttributeType +from .types import Array from funlib.geometry import Roi import logging import sqlite3 import json +import re from pathlib import Path from typing import Optional, Any @@ -15,7 +17,7 @@ class SQLiteGraphDataBase(SQLGraphDataBase): def __init__( self, db_file: Path, - position_attributes: list[str], + position_attribute: str, mode: str = "r+", directed: Optional[bool] = None, total_roi: Optional[Roi] = None, @@ -31,8 +33,8 @@ def __init__( self.cur = self.con.cursor() super().__init__( - position_attributes, mode=mode, + position_attribute=position_attribute, directed=directed, total_roi=total_roi, nodes_table=nodes_table, @@ -42,6 +44,22 @@ def __init__( edge_attrs=edge_attrs, ) + # in SQLite, array types are stored in individual columns + self.node_array_columns = { + attr: [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + for attr, attr_type in self.node_attrs.items() + if isinstance(attr_type, Array) + } + self.edge_array_columns = { + attr: [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + for attr, attr_type in self.edge_attrs.items() + if isinstance(attr_type, Array) + } + def _drop_tables(self) -> None: logger.info( "dropping collections %s, %s", @@ -57,11 +75,13 @@ def _drop_tables(self) -> None: self.meta_collection.unlink() def _create_tables(self) -> None: - position_template = "{pos_attr} REAL not null" - columns = [ - position_template.format(pos_attr=pos_attr) - for pos_attr in self.position_attributes - ] + list(self.node_attrs.keys()) + columns = list(self.node_attrs.keys()) + columns.remove(self.position_attribute) + position_columns = [ + self.position_attribute + f"_{d}" + for d in range(self.ndims) + ] + columns += [ p + " REAL NOT NULL" for p in position_columns ] self.cur.execute( f"CREATE TABLE IF NOT EXISTS " f"{self.nodes_table_name}(" @@ -70,7 +90,7 @@ def _create_tables(self) -> None: ")" ) self.cur.execute( - f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(self.position_attributes)})" + f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(position_columns)})" ) edge_columns = [ f"{self.endpoint_names[0]} INTEGER not null", @@ -95,12 +115,41 @@ def _read_metadata(self) -> Optional[dict[str, Any]]: return json.load(f) def _select_query(self, query): + + # replace array_name[1] with array_name_0 + # ^^^ + # Yes, that's not a typo + # + # If SQL dialects allow array element access, they start counting at 1. + # We don't want that, we start counting at 0 like normal people. + query = re.sub(r'\[(\d+)\]', lambda m: "_" + str(int(m.group(1)) - 1), query) + try: return self.cur.execute(query) except sqlite3.OperationalError as e: raise ValueError(query) from e def _insert_query(self, table, columns, values, fail_if_exists=False, commit=True): + + # explode array attributes into multiple columns + + exploded_values = [] + for row in values: + exploded_columns = [] + exploded_row_values = [] + for column, value in zip(columns, row): + if column in self.node_array_columns: + for c, v in zip(self.node_array_columns[column], value): + exploded_columns.append(c) + exploded_row_values.append(v) + else: + exploded_columns.append(column) + exploded_row_values.append(value) + exploded_values.append(exploded_row_values) + + columns = exploded_columns + values = exploded_values + insert_statement = ( f"INSERT{' OR IGNORE' if not fail_if_exists else ''} INTO {table} " f"({', '.join(columns)}) VALUES ({', '.join(['?'] * len(columns))})" @@ -121,3 +170,53 @@ def _update_query(self, query, commit=True): def _commit(self): self.con.commit() + + def _node_attrs_to_columns(self, attrs): + columns = [] + for attr in attrs: + attr_type = self.node_attrs[attr] + if isinstance(attr_type, Array): + columns += [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + else: + columns.append(attr) + return columns + + def _columns_to_node_attrs(self, columns, query_attrs): + attrs = {} + for attr in query_attrs: + if attr in self.node_array_columns: + value = tuple( + columns[f"{attr}_{d}"] + for d in range(self.node_attrs[attr].size) + ) + else: + value = columns[attr] + attrs[attr] = value + return attrs + + def _edge_attrs_to_columns(self, attrs): + columns = [] + for attr in attrs: + attr_type = self.edge_attrs[attr] + if isinstance(attr_type, Array): + columns += [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + else: + columns.append(attr) + return columns + + def _columns_to_edge_attrs(self, columns, query_attrs): + attrs = {} + for attr in query_attrs: + if attr in self.edge_array_columns: + value = tuple( + columns[f"{attr}_{d}"] + for d in range(self.edge_attrs[attr].size) + ) + else: + value = columns[attr] + attrs[attr] = value + return attrs diff --git a/tests/conftest.py b/tests/conftest.py index 995ce49..7e220e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ def sqlite_provider_factory( ): return SQLiteGraphDataBase( tmpdir / "test_sqlite_graph.db", - position_attributes=["z", "y", "x"], + position_attribute="position", mode=mode, directed=directed, total_roi=total_roi, @@ -32,7 +32,7 @@ def psql_provider_factory( mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None ): return PgSQLGraphDatabase( - position_attributes=["z", "y", "x"], + position_attribute="position", db_name="pytest", mode=mode, directed=directed, diff --git a/tests/test_graph.py b/tests/test_graph.py index 7c24800..1ee986b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,4 +1,5 @@ from funlib.geometry import Roi, Coordinate +from funlib.persistence.graphs import Array import networkx as nx import pytest @@ -6,15 +7,15 @@ def test_graph_filtering(provider_factory): graph_writer = provider_factory( - "w", node_attrs={"selected": bool}, edge_attrs={"selected": bool} + "w", node_attrs={"position": Array(float, 3), "selected": bool}, edge_attrs={"selected": bool} ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_writer[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True) - graph.add_node(42, z=1, y=1, x=1, selected=False) - graph.add_node(23, z=5, y=5, x=5, selected=True) - graph.add_node(57, z=7, y=7, x=7, selected=True) + graph.add_node(2, position=(2, 2, 2), selected=True) + graph.add_node(42, position=(1, 1, 1), selected=False) + graph.add_node(23, position=(5, 5, 5), selected=True) + graph.add_node(57, position=(7, 7, 7), selected=True) graph.add_edge(42, 23, selected=False) graph.add_edge(57, 23, selected=True) graph.add_edge(2, 42, selected=True) @@ -39,7 +40,7 @@ def test_graph_filtering(provider_factory): roi, nodes_filter={"selected": True}, edges_filter={"selected": True} ) nodes_with_position = [ - node for node, data in filtered_subgraph.nodes(data=True) if "z" in data + node for node, data in filtered_subgraph.nodes(data=True) if "position" in data ] assert expected_node_ids == nodes_with_position assert len(filtered_subgraph.edges()) == len(expected_edge_endpoints) @@ -53,16 +54,16 @@ def test_graph_filtering(provider_factory): def test_graph_filtering_complex(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"selected": bool, "test": str}, + node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True, test="test") - graph.add_node(42, z=1, y=1, x=1, selected=False, test="test2") - graph.add_node(23, z=5, y=5, x=5, selected=True, test="test2") - graph.add_node(57, z=7, y=7, x=7, selected=True, test="test") + graph.add_node(2, position=(2, 2, 2), selected=True, test="test") + graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") + graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") + graph.add_node(57, position=(7, 7, 7), selected=True, test="test") graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) @@ -94,7 +95,7 @@ def test_graph_filtering_complex(provider_factory): edges_filter={"selected": True, "a": 100}, ) nodes_with_position = [ - node for node, data in filtered_subgraph.nodes(data=True) if "z" in data + node for node, data in filtered_subgraph.nodes(data=True) if "position" in data ] assert expected_node_ids == nodes_with_position assert len(filtered_subgraph.edges()) == 0 @@ -103,16 +104,16 @@ def test_graph_filtering_complex(provider_factory): def test_graph_read_and_update_specific_attrs(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"selected": bool, "test": str}, + node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int, "c": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True, test="test") - graph.add_node(42, z=1, y=1, x=1, selected=False, test="test2") - graph.add_node(23, z=5, y=5, x=5, selected=True, test="test2") - graph.add_node(57, z=7, y=7, x=7, selected=True, test="test") + graph.add_node(2, position=(2, 2, 2), selected=True, test="test") + graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") + graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") + graph.add_node(57, position=(7, 7, 7), selected=True, test="test") graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) @@ -154,7 +155,7 @@ def test_graph_read_and_update_specific_attrs(provider_factory): def test_graph_read_unbounded_roi(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"selected": bool, "test": str}, + node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) @@ -162,10 +163,10 @@ def test_graph_read_unbounded_roi(provider_factory): graph = graph_provider[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True, test="test") - graph.add_node(42, z=1, y=1, x=1, selected=False, test="test2") - graph.add_node(23, z=5, y=5, x=5, selected=True, test="test2") - graph.add_node(57, z=7, y=7, x=7, selected=True, test="test") + graph.add_node(2, position=(2, 2, 2), selected=True, test="test") + graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") + graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") + graph.add_node(57, position=(7, 7, 7), selected=True, test="test") graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) @@ -196,14 +197,14 @@ def test_graph_read_unbounded_roi(provider_factory): def test_graph_read_meta_values(provider_factory): roi = Roi((0, 0, 0), (10, 10, 10)) - provider_factory("w", True, roi) + provider_factory("w", True, roi, node_attrs={"position": Array(float, 3)}) graph_provider = provider_factory("r", None, None) assert True == graph_provider.directed assert roi == graph_provider.total_roi def test_graph_default_meta_values(provider_factory): - provider = provider_factory("w", False, None) + provider = provider_factory("w", False, None, node_attrs={"position": Array(float, 3)}) assert False == provider.directed assert provider.total_roi is None or provider.total_roi == Roi( (None, None, None), (None, None, None) @@ -215,26 +216,22 @@ def test_graph_default_meta_values(provider_factory): ) -def test_graph_nonmatching_meta_values(provider_factory): - roi = Roi((0, 0, 0), (10, 10, 10)) - roi2 = Roi((1, 0, 0), (10, 10, 10)) - provider_factory("w", True, None) - with pytest.raises(ValueError): - provider_factory("r", False, None) - provider_factory("w", None, roi) - with pytest.raises(ValueError): - provider_factory("r", None, roi2) - - def test_graph_io(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -263,13 +260,20 @@ def test_graph_io(provider_factory): def test_graph_fail_if_exists(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -282,13 +286,20 @@ def test_graph_fail_if_exists(provider_factory): def test_graph_fail_if_not_exists(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -302,19 +313,26 @@ def test_graph_fail_if_not_exists(provider_factory): def test_graph_write_attributes(provider_factory): - graph_provider = provider_factory("w", node_attrs={"swip": str}) + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(int, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=[0, 0, 0]) + graph.add_node(42, position=[1, 1, 1]) + graph.add_node(23, position=[5, 5, 5], swip="swap") + graph.add_node(57, position=[7, 7, 7], zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) graph_provider.write_graph( - graph, write_nodes=True, write_edges=False, node_attrs=["swip"] + graph, write_nodes=True, write_edges=False, node_attrs=["position", "swip"] ) graph_provider.write_edges( @@ -337,17 +355,34 @@ def test_graph_write_attributes(provider_factory): compare_nodes = [ (node_id, data) for node_id, data in compare_nodes if len(data) > 0 ] - assert nodes == compare_nodes + for n, c in zip(nodes, compare_nodes): + assert n[0] == c[0] + for key in n[1]: + assert key in c[1] + v1 = n[1][key] + v2 = c[1][key] + try: + for e1, e2 in zip(v1, v2): + assert e1 == e2 + except: + assert v1 == v2 def test_graph_write_roi(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -373,13 +408,20 @@ def test_graph_write_roi(provider_factory): def test_graph_connected_components(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(57, 23) graph.add_edge(2, 42) try: @@ -404,15 +446,22 @@ def test_graph_connected_components(provider_factory): def test_graph_has_edge(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23)