diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/funlib/persistence/graphs/pgsql_graph_database.py index ec40146..6c701c5 100644 --- a/funlib/persistence/graphs/pgsql_graph_database.py +++ b/funlib/persistence/graphs/pgsql_graph_database.py @@ -14,7 +14,7 @@ class PgSQLGraphDatabase(SQLGraphDataBase): def __init__( self, - position_attributes: list[str], + position_attribute: str, db_name: str, db_host: str = "localhost", db_user: Optional[str] = None, @@ -62,8 +62,8 @@ def __init__( self.cur = self.connection.cursor() super().__init__( - position_attributes, mode=mode, + position_attribute=position_attribute, directed=directed, total_roi=total_roi, nodes_table=nodes_table, @@ -86,10 +86,8 @@ def _drop_tables(self) -> None: self._commit() def _create_tables(self) -> None: - columns = self.position_attributes + list(self.node_attrs.keys()) - types = [self.__sql_type(float) + " NOT NULL"] * len( - self.position_attributes - ) + list([self.__sql_type(t) for t in self.node_attrs.values()]) + columns = self.node_attrs.keys() + types = [self.__sql_type(t) for t in self.node_attrs.values()] column_types = [f"{c} {t}" for c, t in zip(columns, types)] self.__exec( f"CREATE TABLE IF NOT EXISTS " @@ -100,7 +98,7 @@ def _create_tables(self) -> None: ) self.__exec( f"CREATE INDEX IF NOT EXISTS pos_index ON " - f"{self.nodes_table_name}({','.join(self.position_attributes)})" + f"{self.nodes_table_name}({self.position_attribute})" ) columns = list(self.edge_attrs.keys()) @@ -139,6 +137,7 @@ def _read_metadata(self) -> Optional[dict[str, Any]]: return None def _select_query(self, query) -> Iterable[Any]: + print(query) self.__exec(query) return self.cur diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index 9627623..f3af09b 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -17,16 +17,21 @@ class SQLGraphDataBase(GraphDataBase): """Base class for SQL-based graph databases. - Nodes must have position attributes (set via argument - ``position_attributes``), which will be used for geometric slicing (see + Nodes must have a position attribute (set via argument + ``position_attribute``), which will be used for geometric slicing (see ``__getitem__`` and ``read_graph``). Arguments: - position_attributes (list of ``string``s): + mode (``string``): - The node attributes that contain position information. This will - be used for slicing subgraphs via ``__getitem__``. + Any of ``r`` (read-only), ``r+`` (read and allow modifications), + or ``w`` (create new database, overwrite if exists). + + position_attribute (``string``): + + The node attribute that contains position information. This will be + used for slicing subgraphs via ``__getitem__``. directed (``bool``): @@ -56,38 +61,95 @@ class SQLGraphDataBase(GraphDataBase): The custom attributes to store on each edge. """ + read_modes = ["r", "r+"] + write_modes = ["r+", "w"] + create_modes = ["w"] + valid_modes = ["r", "r+", "w"] + _node_attrs: Optional[dict[str, AttributeType]] = None _edge_attrs: Optional[dict[str, AttributeType]] = None def __init__( self, - position_attributes: list[str], mode: str = "r+", + position_attribute: Optional[str] = None, directed: Optional[bool] = None, total_roi: Optional[Roi] = None, - nodes_table: str = "nodes", - edges_table: str = "edges", + nodes_table: Optional[str] = None, + edges_table: Optional[str] = None, endpoint_names: Optional[list[str]] = None, node_attrs: Optional[dict[str, AttributeType]] = None, edge_attrs: Optional[dict[str, AttributeType]] = None, ): - self.position_attributes = position_attributes - self.ndim = len(self.position_attributes) + assert mode in self.valid_modes, f"Mode '{mode}' not in allowed modes {self.valid_modes}" self.mode = mode - self.directed = directed - self.total_roi = total_roi - self.nodes_table_name = nodes_table - self.edges_table_name = edges_table - self.endpoint_names = ["u", "v"] if endpoint_names is None else endpoint_names - self._node_attrs = node_attrs - self._edge_attrs = edge_attrs + if mode in self.read_modes: + + self.position_attribute = position_attribute + self.directed = directed + self.total_roi = total_roi + self.nodes_table_name = nodes_table + self.edges_table_name = edges_table + self.endpoint_names = endpoint_names + self._node_attrs = node_attrs + self._edge_attrs = edge_attrs + self.ndims = None # to be read from metadata + + metadata = self._read_metadata() + self.__load_metadata(metadata) + + if mode in self.create_modes: + + # this is where we populate default values for the DB creation + + assert node_attrs is not None, ( + "For DB creation (mode 'w'), node_attrs is a required " + "argument and needs to contain at least the type definition " + "for the position attribute" + ) + + def get(value, default): + return value if value is not None else default + + self.position_attribute = get(position_attribute, "position") + + assert self.position_attribute in node_attrs, ( + "No type information for position attribute " + f"'{self.position_attribute}' in 'node_attrs'" + ) - if mode == "w": + position_type = node_attrs[self.position_attribute] + if isinstance(position_type, Array): + self.ndims = position_type.size + assert self.ndims > 1, ( + "Don't use Arrays of size 1 for the position, use the " + "scalar type directly instead (i.e., 'float' instead of " + "'Array(float, 1)'." + ) + # if ndims == 1, we know that we have a single scalar now + else: + self.ndims = 1 + + self.directed = get(directed, False) + self.total_roi = get( + total_roi, + Roi((None,) * self.ndims, (None,) * self.ndims)) + self.nodes_table_name = get(nodes_table, "nodes") + self.edges_table_name = get(edges_table, "edges") + self.endpoint_names = get(endpoint_names, ["u", "v"]) + self._node_attrs = node_attrs # no default, needs to be given + self._edge_attrs = get(edge_attrs, {}) + + # delete previous DB, if exists self._drop_tables() - self._create_tables() - self.__init_metadata() + # create new DB + self._create_tables() + + # store metadata + metadata = self.__create_metadata() + self._store_metadata(metadata) @abstractmethod def _drop_tables(self) -> None: @@ -123,6 +185,22 @@ def _update_query(self, query, commit=True) -> None: def _commit(self) -> None: pass + def _node_attrs_to_columns(self, attrs): + # default: each attribute maps to its own column + return attrs + + def _columns_to_node_attrs(self, columns, attrs): + # default: each column maps to one attribute + return columns + + def _edge_attrs_to_columns(self, attrs): + # default: each attribute maps to its own column + return attrs + + def _columns_to_edge_attrs(self, columns, attrs): + # default: each column maps to one attribute + return columns + def read_graph( self, roi: Optional[Roi] = None, @@ -230,9 +308,17 @@ def read_nodes( ) -> list[dict[str, Any]]: """Return a list of nodes within roi.""" + # attributes to read + read_attrs = list(self.node_attrs.keys()) if read_attrs is None else read_attrs + + # corresponding column naes + read_columns = ["id"] + self._node_attrs_to_columns(read_attrs) + read_attrs = ["id"] + read_attrs + read_attrs_query = ", ".join(read_columns) + logger.debug("Reading nodes in roi %s" % roi) select_statement = ( - f"SELECT * FROM {self.nodes_table_name} " + f"SELECT {read_attrs_query} FROM {self.nodes_table_name} " + (self.__roi_query(roi) if roi is not None else "") + ( f" {'WHERE' if roi is None else 'AND'} " @@ -242,27 +328,26 @@ def read_nodes( ) ) - read_attrs = ( - ["id"] - + self.position_attributes - + (list(self.node_attrs.keys()) if read_attrs is None else read_attrs) - ) attr_filter = attr_filter if attr_filter is not None else {} for k, v in attr_filter.items(): select_statement += f" AND {k}={self.__convert_to_sql(v)}" nodes = [ - { - key: val - for key, val in zip( - ["id"] + self.position_attributes + list(self.node_attrs.keys()), - values, - ) - if key in read_attrs and val is not None - } + self._columns_to_node_attrs( + { + key: val + for key, val in zip(read_columns, values) + }, + read_attrs + ) for values in self._select_query(select_statement) ] + for values in self._select_query(select_statement): + print(values) + print(self.node_attrs.keys()) + print(nodes) + return nodes def num_nodes(self, roi: Roi) -> int: @@ -348,8 +433,8 @@ def write_edges( if roi is None: roi = Roi( - (None,) * len(self.position_attributes), - (None,) * len(self.position_attributes), + (None,) * self.ndims, + (None,) * self.ndims, ) values = [] @@ -438,7 +523,7 @@ def write_nodes( logger.debug("Writing nodes in %s", roi) attrs = attributes if attributes is not None else list(self.node_attrs.keys()) - columns = ("id",) + tuple(self.position_attributes) + tuple(attrs) + columns = ("id",) + tuple(attrs) values = [] for node_id, data in nodes.items(): @@ -446,11 +531,9 @@ def write_nodes( pos = self.__get_node_pos(data) if roi is not None and not roi.contains(pos): continue - for i, position_attribute in enumerate(self.position_attributes): - data[position_attribute] = pos[i] values.append( [node_id] - + [data.get(attr, None) for attr in self.position_attributes + attrs] + + [data.get(attr, None) for attr in attrs] ) if len(values) == 0: @@ -498,83 +581,77 @@ def update_nodes( self._commit() - def __init_metadata(self): - metadata = self._read_metadata() - - if metadata: - self.__check_metadata(metadata) - else: - metadata = self.__create_metadata() - self._store_metadata(metadata) - def __create_metadata(self): """Sets the metadata in the meta collection to the provided values""" - if not self.directed: - # default is False - self.directed = self.directed if self.directed is not None else False - if not self.total_roi: - # default is an unbounded roi - self.total_roi = Roi( - (None,) * len(self.position_attributes), - (None,) * len(self.position_attributes), - ) - metadata = { + "position_attribute": self.position_attribute, "directed": self.directed, "total_roi_offset": self.total_roi.offset, "total_roi_shape": self.total_roi.shape, + "nodes_table_name": self.nodes_table_name, + "edges_table_name": self.edges_table_name, + "endpoint_names": self.endpoint_names, "node_attrs": {k: type_to_str(v) for k, v in self.node_attrs.items()}, "edge_attrs": {k: type_to_str(v) for k, v in self.edge_attrs.items()}, + "ndims": self.ndims, } return metadata - def __check_metadata(self, metadata): - """Checks if the provided metadata matches the existing - metadata in the meta collection""" - - if self.directed is not None and metadata["directed"] != self.directed: - raise ValueError( - ( - "Input parameter directed={} does not match" - "directed value {} already in stored metadata" - ).format(self.directed, metadata["directed"]) - ) - elif self.directed is None: - self.directed = metadata["directed"] - if self.total_roi is not None: - if self.total_roi.get_offset() != metadata["total_roi_offset"]: - raise ValueError( - ( - "Input total_roi offset {} does not match" - "total_roi offset {} already stored in metadata" - ).format(self.total_roi.get_offset(), metadata["total_roi_offset"]) - ) - if self.total_roi.get_shape() != metadata["total_roi_shape"]: - raise ValueError( - ( - "Input total_roi shape {} does not match" - "total_roi shape {} already stored in metadata" - ).format(self.total_roi.get_shape(), metadata["total_roi_shape"]) + def __load_metadata(self, metadata): + """Load the provided metadata into this object's attributes, check if + it is consistent with already populated fields.""" + + # simple attributes + for attr_name in [ + "position_attribute", + "directed", + "nodes_table_name", + "edges_table_name", + "endpoint_names", + "ndims"]: + + if getattr(self, attr_name) is None: + setattr(self, attr_name, metadata[attr_name]) + else: + value = getattr(self, attr_name) + assert value == metadata[attr_name], ( + f"Attribute {attr_name} is already set to {value} for this " + "object, but disagrees with the stored metadata value of " + f"{metadata[attr_name]}" ) + + # special attributes + + total_roi = Roi(metadata["total_roi_offset"], metadata["total_roi_shape"]) + if self.total_roi is None: + self.total_roi = total_roi else: - self.total_roi = Roi( - metadata["total_roi_offset"], metadata["total_roi_shape"] - ) - metadata["node_attrs"] = {k: eval(v) for k, v in metadata["node_attrs"].items()} - metadata["edge_attrs"] = {k: eval(v) for k, v in metadata["edge_attrs"].items()} - if self._node_attrs is not None: - assert self.node_attrs == metadata["node_attrs"], ( - self.node_attrs, - metadata["node_attrs"], + assert self.total_roi == total_roi, ( + f"Attribute total_roi is already set to {self.total_roi} for " + "this object, but disagrees with the stored metadata value of " + f"{total_roi}" ) + + node_attrs = {k: eval(v) for k, v in metadata["node_attrs"].items()} + edge_attrs = {k: eval(v) for k, v in metadata["edge_attrs"].items()} + if self._node_attrs is None: + self.node_attrs = node_attrs else: - self.node_attrs = metadata["node_attrs"] - if self._edge_attrs is not None: - assert self.edge_attrs == metadata["edge_attrs"] + assert self.node_attrs == node_attrs, ( + f"Attribute node_attrs is already set to {self.node_attrs} for " + "this object, but disagrees with the stored metadata value of " + f"{node_attrs}" + ) + if self._edge_attrs is None: + self.edge_attrs = edge_attrs else: - self.edge_attrs = metadata["edge_attrs"] + assert self.edge_attrs == edge_attrs, ( + f"Attribute edge_attrs is already set to {self.edge_attrs} for " + "this object, but disagrees with the stored metadata value of " + f"{edge_attrs}" + ) def __remove_keys(self, dictionary, keys): """Removes given keys from dictionary.""" @@ -582,9 +659,7 @@ def __remove_keys(self, dictionary, keys): return {k: v for k, v in dictionary.items() if k not in keys} def __get_node_pos(self, n: dict[str, Any]) -> Coordinate: - return Coordinate( - (n.get(pos_attr, None) for pos_attr in self.position_attributes) - ) + return Coordinate(n[self.position_attribute]) def __convert_to_sql(self, x: Any) -> str: if isinstance(x, str): @@ -605,15 +680,16 @@ def __attr_query(self, attrs: dict[str, Any]) -> str: def __roi_query(self, roi: Roi) -> str: query = "WHERE " - for dim, pos_attr in enumerate(self.position_attributes): + pos_attr = self.position_attribute + for dim in range(self.ndims): if dim > 0: query += " AND " if roi.begin[dim] is not None and roi.end[dim] is not None: - query += f"{pos_attr} BETWEEN {roi.begin[dim]} and {roi.end[dim]}" + query += f"{pos_attr}[{dim + 1}] BETWEEN {roi.begin[dim]} and {roi.end[dim]}" elif roi.begin[dim] is not None: - query += f"{pos_attr}>={roi.begin[dim]}" + query += f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" elif roi.begin[dim] is not None: - query += f"{pos_attr}<{roi.end[dim]}" + query += f"{pos_attr}[{dim + 1}]<{roi.end[dim]}" else: query = query[:-5] return query diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/funlib/persistence/graphs/sqlite_graph_database.py index 5d45b4d..4c3760e 100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/funlib/persistence/graphs/sqlite_graph_database.py @@ -1,10 +1,12 @@ from .sql_graph_database import SQLGraphDataBase, AttributeType +from .types import Array from funlib.geometry import Roi import logging import sqlite3 import json +import re from pathlib import Path from typing import Optional, Any @@ -15,7 +17,7 @@ class SQLiteGraphDataBase(SQLGraphDataBase): def __init__( self, db_file: Path, - position_attributes: list[str], + position_attribute: str, mode: str = "r+", directed: Optional[bool] = None, total_roi: Optional[Roi] = None, @@ -31,8 +33,8 @@ def __init__( self.cur = self.con.cursor() super().__init__( - position_attributes, mode=mode, + position_attribute=position_attribute, directed=directed, total_roi=total_roi, nodes_table=nodes_table, @@ -42,6 +44,22 @@ def __init__( edge_attrs=edge_attrs, ) + # in SQLite, array types are stored in individual columns + self.node_array_columns = { + attr: [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + for attr, attr_type in self.node_attrs.items() + if isinstance(attr_type, Array) + } + self.edge_array_columns = { + attr: [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + for attr, attr_type in self.edge_attrs.items() + if isinstance(attr_type, Array) + } + def _drop_tables(self) -> None: logger.info( "dropping collections %s, %s", @@ -57,11 +75,13 @@ def _drop_tables(self) -> None: self.meta_collection.unlink() def _create_tables(self) -> None: - position_template = "{pos_attr} REAL not null" - columns = [ - position_template.format(pos_attr=pos_attr) - for pos_attr in self.position_attributes - ] + list(self.node_attrs.keys()) + columns = list(self.node_attrs.keys()) + columns.remove(self.position_attribute) + position_columns = [ + self.position_attribute + f"_{d}" + for d in range(self.ndims) + ] + columns += [ p + " REAL NOT NULL" for p in position_columns ] self.cur.execute( f"CREATE TABLE IF NOT EXISTS " f"{self.nodes_table_name}(" @@ -70,7 +90,7 @@ def _create_tables(self) -> None: ")" ) self.cur.execute( - f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(self.position_attributes)})" + f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(position_columns)})" ) edge_columns = [ f"{self.endpoint_names[0]} INTEGER not null", @@ -95,12 +115,41 @@ def _read_metadata(self) -> Optional[dict[str, Any]]: return json.load(f) def _select_query(self, query): + + # replace array_name[1] with array_name_0 + # ^^^ + # Yes, that's not a typo + # + # If SQL dialects allow array element access, they start counting at 1. + # We don't want that, we start counting at 0 like normal people. + query = re.sub(r'\[(\d+)\]', lambda m: "_" + str(int(m.group(1)) - 1), query) + try: return self.cur.execute(query) except sqlite3.OperationalError as e: raise ValueError(query) from e def _insert_query(self, table, columns, values, fail_if_exists=False, commit=True): + + # explode array attributes into multiple columns + + exploded_values = [] + for row in values: + exploded_columns = [] + exploded_row_values = [] + for column, value in zip(columns, row): + if column in self.node_array_columns: + for c, v in zip(self.node_array_columns[column], value): + exploded_columns.append(c) + exploded_row_values.append(v) + else: + exploded_columns.append(column) + exploded_row_values.append(value) + exploded_values.append(exploded_row_values) + + columns = exploded_columns + values = exploded_values + insert_statement = ( f"INSERT{' OR IGNORE' if not fail_if_exists else ''} INTO {table} " f"({', '.join(columns)}) VALUES ({', '.join(['?'] * len(columns))})" @@ -121,3 +170,53 @@ def _update_query(self, query, commit=True): def _commit(self): self.con.commit() + + def _node_attrs_to_columns(self, attrs): + columns = [] + for attr in attrs: + attr_type = self.node_attrs[attr] + if isinstance(attr_type, Array): + columns += [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + else: + columns.append(attr) + return columns + + def _columns_to_node_attrs(self, columns, query_attrs): + attrs = {} + for attr in query_attrs: + if attr in self.node_array_columns: + value = tuple( + columns[f"{attr}_{d}"] + for d in range(self.node_attrs[attr].size) + ) + else: + value = columns[attr] + attrs[attr] = value + return attrs + + def _edge_attrs_to_columns(self, attrs): + columns = [] + for attr in attrs: + attr_type = self.edge_attrs[attr] + if isinstance(attr_type, Array): + columns += [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + else: + columns.append(attr) + return columns + + def _columns_to_edge_attrs(self, columns, query_attrs): + attrs = {} + for attr in query_attrs: + if attr in self.edge_array_columns: + value = tuple( + columns[f"{attr}_{d}"] + for d in range(self.edge_attrs[attr].size) + ) + else: + value = columns[attr] + attrs[attr] = value + return attrs diff --git a/tests/conftest.py b/tests/conftest.py index 995ce49..7e220e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ def sqlite_provider_factory( ): return SQLiteGraphDataBase( tmpdir / "test_sqlite_graph.db", - position_attributes=["z", "y", "x"], + position_attribute="position", mode=mode, directed=directed, total_roi=total_roi, @@ -32,7 +32,7 @@ def psql_provider_factory( mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None ): return PgSQLGraphDatabase( - position_attributes=["z", "y", "x"], + position_attribute="position", db_name="pytest", mode=mode, directed=directed, diff --git a/tests/test_graph.py b/tests/test_graph.py index 7c24800..1ee986b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,4 +1,5 @@ from funlib.geometry import Roi, Coordinate +from funlib.persistence.graphs import Array import networkx as nx import pytest @@ -6,15 +7,15 @@ def test_graph_filtering(provider_factory): graph_writer = provider_factory( - "w", node_attrs={"selected": bool}, edge_attrs={"selected": bool} + "w", node_attrs={"position": Array(float, 3), "selected": bool}, edge_attrs={"selected": bool} ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_writer[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True) - graph.add_node(42, z=1, y=1, x=1, selected=False) - graph.add_node(23, z=5, y=5, x=5, selected=True) - graph.add_node(57, z=7, y=7, x=7, selected=True) + graph.add_node(2, position=(2, 2, 2), selected=True) + graph.add_node(42, position=(1, 1, 1), selected=False) + graph.add_node(23, position=(5, 5, 5), selected=True) + graph.add_node(57, position=(7, 7, 7), selected=True) graph.add_edge(42, 23, selected=False) graph.add_edge(57, 23, selected=True) graph.add_edge(2, 42, selected=True) @@ -39,7 +40,7 @@ def test_graph_filtering(provider_factory): roi, nodes_filter={"selected": True}, edges_filter={"selected": True} ) nodes_with_position = [ - node for node, data in filtered_subgraph.nodes(data=True) if "z" in data + node for node, data in filtered_subgraph.nodes(data=True) if "position" in data ] assert expected_node_ids == nodes_with_position assert len(filtered_subgraph.edges()) == len(expected_edge_endpoints) @@ -53,16 +54,16 @@ def test_graph_filtering(provider_factory): def test_graph_filtering_complex(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"selected": bool, "test": str}, + node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True, test="test") - graph.add_node(42, z=1, y=1, x=1, selected=False, test="test2") - graph.add_node(23, z=5, y=5, x=5, selected=True, test="test2") - graph.add_node(57, z=7, y=7, x=7, selected=True, test="test") + graph.add_node(2, position=(2, 2, 2), selected=True, test="test") + graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") + graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") + graph.add_node(57, position=(7, 7, 7), selected=True, test="test") graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) @@ -94,7 +95,7 @@ def test_graph_filtering_complex(provider_factory): edges_filter={"selected": True, "a": 100}, ) nodes_with_position = [ - node for node, data in filtered_subgraph.nodes(data=True) if "z" in data + node for node, data in filtered_subgraph.nodes(data=True) if "position" in data ] assert expected_node_ids == nodes_with_position assert len(filtered_subgraph.edges()) == 0 @@ -103,16 +104,16 @@ def test_graph_filtering_complex(provider_factory): def test_graph_read_and_update_specific_attrs(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"selected": bool, "test": str}, + node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int, "c": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True, test="test") - graph.add_node(42, z=1, y=1, x=1, selected=False, test="test2") - graph.add_node(23, z=5, y=5, x=5, selected=True, test="test2") - graph.add_node(57, z=7, y=7, x=7, selected=True, test="test") + graph.add_node(2, position=(2, 2, 2), selected=True, test="test") + graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") + graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") + graph.add_node(57, position=(7, 7, 7), selected=True, test="test") graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) @@ -154,7 +155,7 @@ def test_graph_read_and_update_specific_attrs(provider_factory): def test_graph_read_unbounded_roi(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"selected": bool, "test": str}, + node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) @@ -162,10 +163,10 @@ def test_graph_read_unbounded_roi(provider_factory): graph = graph_provider[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True, test="test") - graph.add_node(42, z=1, y=1, x=1, selected=False, test="test2") - graph.add_node(23, z=5, y=5, x=5, selected=True, test="test2") - graph.add_node(57, z=7, y=7, x=7, selected=True, test="test") + graph.add_node(2, position=(2, 2, 2), selected=True, test="test") + graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") + graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") + graph.add_node(57, position=(7, 7, 7), selected=True, test="test") graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) @@ -196,14 +197,14 @@ def test_graph_read_unbounded_roi(provider_factory): def test_graph_read_meta_values(provider_factory): roi = Roi((0, 0, 0), (10, 10, 10)) - provider_factory("w", True, roi) + provider_factory("w", True, roi, node_attrs={"position": Array(float, 3)}) graph_provider = provider_factory("r", None, None) assert True == graph_provider.directed assert roi == graph_provider.total_roi def test_graph_default_meta_values(provider_factory): - provider = provider_factory("w", False, None) + provider = provider_factory("w", False, None, node_attrs={"position": Array(float, 3)}) assert False == provider.directed assert provider.total_roi is None or provider.total_roi == Roi( (None, None, None), (None, None, None) @@ -215,26 +216,22 @@ def test_graph_default_meta_values(provider_factory): ) -def test_graph_nonmatching_meta_values(provider_factory): - roi = Roi((0, 0, 0), (10, 10, 10)) - roi2 = Roi((1, 0, 0), (10, 10, 10)) - provider_factory("w", True, None) - with pytest.raises(ValueError): - provider_factory("r", False, None) - provider_factory("w", None, roi) - with pytest.raises(ValueError): - provider_factory("r", None, roi2) - - def test_graph_io(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -263,13 +260,20 @@ def test_graph_io(provider_factory): def test_graph_fail_if_exists(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -282,13 +286,20 @@ def test_graph_fail_if_exists(provider_factory): def test_graph_fail_if_not_exists(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -302,19 +313,26 @@ def test_graph_fail_if_not_exists(provider_factory): def test_graph_write_attributes(provider_factory): - graph_provider = provider_factory("w", node_attrs={"swip": str}) + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(int, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=[0, 0, 0]) + graph.add_node(42, position=[1, 1, 1]) + graph.add_node(23, position=[5, 5, 5], swip="swap") + graph.add_node(57, position=[7, 7, 7], zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) graph_provider.write_graph( - graph, write_nodes=True, write_edges=False, node_attrs=["swip"] + graph, write_nodes=True, write_edges=False, node_attrs=["position", "swip"] ) graph_provider.write_edges( @@ -337,17 +355,34 @@ def test_graph_write_attributes(provider_factory): compare_nodes = [ (node_id, data) for node_id, data in compare_nodes if len(data) > 0 ] - assert nodes == compare_nodes + for n, c in zip(nodes, compare_nodes): + assert n[0] == c[0] + for key in n[1]: + assert key in c[1] + v1 = n[1][key] + v2 = c[1][key] + try: + for e1, e2 in zip(v1, v2): + assert e1 == e2 + except: + assert v1 == v2 def test_graph_write_roi(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -373,13 +408,20 @@ def test_graph_write_roi(provider_factory): def test_graph_connected_components(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(57, 23) graph.add_edge(2, 42) try: @@ -404,15 +446,22 @@ def test_graph_connected_components(provider_factory): def test_graph_has_edge(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23)