From 0273ee99d72f3dd993b1b467bb2ae4cd4a8ed80b Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Thu, 7 Mar 2024 10:07:17 -0500 Subject: [PATCH 01/11] Add support for array data types in GraphDBs --- funlib/persistence/graphs/__init__.py | 1 + funlib/persistence/graphs/graph_database.py | 8 +++++-- .../graphs/pgsql_graph_database.py | 6 +++++ .../persistence/graphs/sql_graph_database.py | 23 ++++++++++--------- .../graphs/sqlite_graph_database.py | 6 ++--- funlib/persistence/graphs/types.py | 14 +++++++++++ 6 files changed, 42 insertions(+), 16 deletions(-) create mode 100644 funlib/persistence/graphs/types.py diff --git a/funlib/persistence/graphs/__init__.py b/funlib/persistence/graphs/__init__.py index 6b767a2..f4627c1 100644 --- a/funlib/persistence/graphs/__init__.py +++ b/funlib/persistence/graphs/__init__.py @@ -1,2 +1,3 @@ from .sqlite_graph_database import SQLiteGraphDataBase # noqa from .pgsql_graph_database import PgSQLGraphDatabase # noqa +from .types import Array # noqa diff --git a/funlib/persistence/graphs/graph_database.py b/funlib/persistence/graphs/graph_database.py index 9716ea9..b452f4f 100644 --- a/funlib/persistence/graphs/graph_database.py +++ b/funlib/persistence/graphs/graph_database.py @@ -1,5 +1,6 @@ from networkx import Graph from funlib.geometry import Roi +from .types import Array import logging from abc import ABC, abstractmethod @@ -9,6 +10,9 @@ logger = logging.getLogger(__name__) +AttributeType = type | str | Array + + class GraphDataBase(ABC): """ Interface for graph databases that supports slicing to retrieve @@ -33,7 +37,7 @@ def __getitem__(self, roi) -> Graph: @property @abstractmethod - def node_attrs(self) -> dict[str, type]: + def node_attrs(self) -> dict[str, AttributeType]: """ Return the node attributes supported by the database. 
""" @@ -41,7 +45,7 @@ def node_attrs(self) -> dict[str, type]: @property @abstractmethod - def edge_attrs(self) -> dict[str, type]: + def edge_attrs(self) -> dict[str, AttributeType]: """ Return the edge attributes supported by the database. """ diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/funlib/persistence/graphs/pgsql_graph_database.py index 59de797..ec40146 100644 --- a/funlib/persistence/graphs/pgsql_graph_database.py +++ b/funlib/persistence/graphs/pgsql_graph_database.py @@ -1,10 +1,12 @@ from .sql_graph_database import SQLGraphDataBase +from .types import Array from funlib.geometry import Roi import logging import psycopg2 import json from typing import Optional, Any, Iterable +from collections.abc import Iterable logger = logging.getLogger(__name__) @@ -177,12 +179,16 @@ def __exec(self, query): def __sql_value(self, value): if isinstance(value, str): return f"'{value}'" + if isinstance(value, Iterable): + return f"array[{','.join([self.__sql_value(v) for v in value])}]" elif value is None: return "NULL" else: return str(value) def __sql_type(self, type): + if isinstance(type, Array): + return self.__sql_type(type.dtype) + f"[{type.size}]" try: return {bool: "BOOLEAN", int: "INTEGER", str: "VARCHAR", float: "REAL"}[ type diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index ede8a03..9627623 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -1,4 +1,5 @@ -from .graph_database import GraphDataBase +from .graph_database import GraphDataBase, AttributeType +from .types import Array, type_to_str from funlib.geometry import Coordinate from funlib.geometry import Roi @@ -55,8 +56,8 @@ class SQLGraphDataBase(GraphDataBase): The custom attributes to store on each edge. 
""" - _node_attrs: Optional[dict[str, type]] = None - _edge_attrs: Optional[dict[str, type]] = None + _node_attrs: Optional[dict[str, AttributeType]] = None + _edge_attrs: Optional[dict[str, AttributeType]] = None def __init__( self, @@ -67,8 +68,8 @@ def __init__( nodes_table: str = "nodes", edges_table: str = "edges", endpoint_names: Optional[list[str]] = None, - node_attrs: Optional[dict[str, type]] = None, - edge_attrs: Optional[dict[str, type]] = None, + node_attrs: Optional[dict[str, AttributeType]] = None, + edge_attrs: Optional[dict[str, AttributeType]] = None, ): self.position_attributes = position_attributes self.ndim = len(self.position_attributes) @@ -205,19 +206,19 @@ def write_graph( ) @property - def node_attrs(self) -> dict[str, type]: + def node_attrs(self) -> dict[str, AttributeType]: return self._node_attrs if self._node_attrs is not None else {} @node_attrs.setter - def node_attrs(self, value: dict[str, type]) -> None: + def node_attrs(self, value: dict[str, AttributeType]) -> None: self._node_attrs = value @property - def edge_attrs(self) -> dict[str, type]: + def edge_attrs(self) -> dict[str, AttributeType]: return self._edge_attrs if self._edge_attrs is not None else {} @edge_attrs.setter - def edge_attrs(self, value: dict[str, type]) -> None: + def edge_attrs(self, value: dict[str, AttributeType]) -> None: self._edge_attrs = value def read_nodes( @@ -523,8 +524,8 @@ def __create_metadata(self): "directed": self.directed, "total_roi_offset": self.total_roi.offset, "total_roi_shape": self.total_roi.shape, - "node_attrs": {k: v.__name__ for k, v in self.node_attrs.items()}, - "edge_attrs": {k: v.__name__ for k, v in self.edge_attrs.items()}, + "node_attrs": {k: type_to_str(v) for k, v in self.node_attrs.items()}, + "edge_attrs": {k: type_to_str(v) for k, v in self.edge_attrs.items()}, } return metadata diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/funlib/persistence/graphs/sqlite_graph_database.py index 6ed1c3f..5d45b4d 
100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/funlib/persistence/graphs/sqlite_graph_database.py @@ -1,4 +1,4 @@ -from .sql_graph_database import SQLGraphDataBase +from .sql_graph_database import SQLGraphDataBase, AttributeType from funlib.geometry import Roi @@ -22,8 +22,8 @@ def __init__( nodes_table: str = "nodes", edges_table: str = "edges", endpoint_names: Optional[list[str]] = None, - node_attrs: Optional[dict[str, type]] = None, - edge_attrs: Optional[dict[str, type]] = None, + node_attrs: Optional[dict[str, AttributeType]] = None, + edge_attrs: Optional[dict[str, AttributeType]] = None, ): self.db_file = db_file self.meta_collection = self.db_file.parent / f"{self.db_file.stem}-meta.json" diff --git a/funlib/persistence/graphs/types.py b/funlib/persistence/graphs/types.py new file mode 100644 index 0000000..c49f2d6 --- /dev/null +++ b/funlib/persistence/graphs/types.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + + +@dataclass +class Array: + dtype: type | str + size: int + + +def type_to_str(type): + if isinstance(type, Array): + return f"Array({type_to_str(type.dtype)}, {type.size})" + else: + return type.__name__ From c949e97c206d2df557f81da847ddf6a5dc633f45 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Thu, 7 Mar 2024 17:38:18 -0500 Subject: [PATCH 02/11] Store position attribute as Array type The list of position attributes has now been replaced with a single attribute, which can be an Array. PostgreSQL will store the array directly, SQLite will create one column for each element. This commit also tightens the requirements for optional arguments passed to the graph databases and simplifies the metadata checks. There is now also documentation on the allowed read/write modes and checks for the passed arguments. 
--- .../graphs/pgsql_graph_database.py | 13 +- .../persistence/graphs/sql_graph_database.py | 288 +++++++++++------- .../graphs/sqlite_graph_database.py | 115 ++++++- tests/conftest.py | 4 +- tests/test_graph.py | 193 +++++++----- 5 files changed, 418 insertions(+), 195 deletions(-) diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/funlib/persistence/graphs/pgsql_graph_database.py index ec40146..6c701c5 100644 --- a/funlib/persistence/graphs/pgsql_graph_database.py +++ b/funlib/persistence/graphs/pgsql_graph_database.py @@ -14,7 +14,7 @@ class PgSQLGraphDatabase(SQLGraphDataBase): def __init__( self, - position_attributes: list[str], + position_attribute: str, db_name: str, db_host: str = "localhost", db_user: Optional[str] = None, @@ -62,8 +62,8 @@ def __init__( self.cur = self.connection.cursor() super().__init__( - position_attributes, mode=mode, + position_attribute=position_attribute, directed=directed, total_roi=total_roi, nodes_table=nodes_table, @@ -86,10 +86,8 @@ def _drop_tables(self) -> None: self._commit() def _create_tables(self) -> None: - columns = self.position_attributes + list(self.node_attrs.keys()) - types = [self.__sql_type(float) + " NOT NULL"] * len( - self.position_attributes - ) + list([self.__sql_type(t) for t in self.node_attrs.values()]) + columns = self.node_attrs.keys() + types = [self.__sql_type(t) for t in self.node_attrs.values()] column_types = [f"{c} {t}" for c, t in zip(columns, types)] self.__exec( f"CREATE TABLE IF NOT EXISTS " @@ -100,7 +98,7 @@ def _create_tables(self) -> None: ) self.__exec( f"CREATE INDEX IF NOT EXISTS pos_index ON " - f"{self.nodes_table_name}({','.join(self.position_attributes)})" + f"{self.nodes_table_name}({self.position_attribute})" ) columns = list(self.edge_attrs.keys()) @@ -139,6 +137,7 @@ def _read_metadata(self) -> Optional[dict[str, Any]]: return None def _select_query(self, query) -> Iterable[Any]: + print(query) self.__exec(query) return self.cur diff --git 
a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index 9627623..f3af09b 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -17,16 +17,21 @@ class SQLGraphDataBase(GraphDataBase): """Base class for SQL-based graph databases. - Nodes must have position attributes (set via argument - ``position_attributes``), which will be used for geometric slicing (see + Nodes must have a position attribute (set via argument + ``position_attribute``), which will be used for geometric slicing (see ``__getitem__`` and ``read_graph``). Arguments: - position_attributes (list of ``string``s): + mode (``string``): - The node attributes that contain position information. This will - be used for slicing subgraphs via ``__getitem__``. + Any of ``r`` (read-only), ``r+`` (read and allow modifications), + or ``w`` (create new database, overwrite if exists). + + position_attribute (``string``): + + The node attribute that contains position information. This will be + used for slicing subgraphs via ``__getitem__``. directed (``bool``): @@ -56,38 +61,95 @@ class SQLGraphDataBase(GraphDataBase): The custom attributes to store on each edge. 
""" + read_modes = ["r", "r+"] + write_modes = ["r+", "w"] + create_modes = ["w"] + valid_modes = ["r", "r+", "w"] + _node_attrs: Optional[dict[str, AttributeType]] = None _edge_attrs: Optional[dict[str, AttributeType]] = None def __init__( self, - position_attributes: list[str], mode: str = "r+", + position_attribute: Optional[str] = None, directed: Optional[bool] = None, total_roi: Optional[Roi] = None, - nodes_table: str = "nodes", - edges_table: str = "edges", + nodes_table: Optional[str] = None, + edges_table: Optional[str] = None, endpoint_names: Optional[list[str]] = None, node_attrs: Optional[dict[str, AttributeType]] = None, edge_attrs: Optional[dict[str, AttributeType]] = None, ): - self.position_attributes = position_attributes - self.ndim = len(self.position_attributes) + assert mode in self.valid_modes, f"Mode '{mode}' not in allowed modes {self.valid_modes}" self.mode = mode - self.directed = directed - self.total_roi = total_roi - self.nodes_table_name = nodes_table - self.edges_table_name = edges_table - self.endpoint_names = ["u", "v"] if endpoint_names is None else endpoint_names - self._node_attrs = node_attrs - self._edge_attrs = edge_attrs + if mode in self.read_modes: + + self.position_attribute = position_attribute + self.directed = directed + self.total_roi = total_roi + self.nodes_table_name = nodes_table + self.edges_table_name = edges_table + self.endpoint_names = endpoint_names + self._node_attrs = node_attrs + self._edge_attrs = edge_attrs + self.ndims = None # to be read from metadata + + metadata = self._read_metadata() + self.__load_metadata(metadata) + + if mode in self.create_modes: + + # this is where we populate default values for the DB creation + + assert node_attrs is not None, ( + "For DB creation (mode 'w'), node_attrs is a required " + "argument and needs to contain at least the type definition " + "for the position attribute" + ) + + def get(value, default): + return value if value is not None else default + + 
self.position_attribute = get(position_attribute, "position") + + assert self.position_attribute in node_attrs, ( + "No type information for position attribute " + f"'{self.position_attribute}' in 'node_attrs'" + ) - if mode == "w": + position_type = node_attrs[self.position_attribute] + if isinstance(position_type, Array): + self.ndims = position_type.size + assert self.ndims > 1, ( + "Don't use Arrays of size 1 for the position, use the " + "scalar type directly instead (i.e., 'float' instead of " + "'Array(float, 1)'." + ) + # if ndims == 1, we know that we have a single scalar now + else: + self.ndims = 1 + + self.directed = get(directed, False) + self.total_roi = get( + total_roi, + Roi((None,) * self.ndims, (None,) * self.ndims)) + self.nodes_table_name = get(nodes_table, "nodes") + self.edges_table_name = get(edges_table, "edges") + self.endpoint_names = get(endpoint_names, ["u", "v"]) + self._node_attrs = node_attrs # no default, needs to be given + self._edge_attrs = get(edge_attrs, {}) + + # delete previous DB, if exists self._drop_tables() - self._create_tables() - self.__init_metadata() + # create new DB + self._create_tables() + + # store metadata + metadata = self.__create_metadata() + self._store_metadata(metadata) @abstractmethod def _drop_tables(self) -> None: @@ -123,6 +185,22 @@ def _update_query(self, query, commit=True) -> None: def _commit(self) -> None: pass + def _node_attrs_to_columns(self, attrs): + # default: each attribute maps to its own column + return attrs + + def _columns_to_node_attrs(self, columns, attrs): + # default: each column maps to one attribute + return columns + + def _edge_attrs_to_columns(self, attrs): + # default: each attribute maps to its own column + return attrs + + def _columns_to_edge_attrs(self, columns, attrs): + # default: each column maps to one attribute + return columns + def read_graph( self, roi: Optional[Roi] = None, @@ -230,9 +308,17 @@ def read_nodes( ) -> list[dict[str, Any]]: """Return a list of 
nodes within roi.""" + # attributes to read + read_attrs = list(self.node_attrs.keys()) if read_attrs is None else read_attrs + + # corresponding column naes + read_columns = ["id"] + self._node_attrs_to_columns(read_attrs) + read_attrs = ["id"] + read_attrs + read_attrs_query = ", ".join(read_columns) + logger.debug("Reading nodes in roi %s" % roi) select_statement = ( - f"SELECT * FROM {self.nodes_table_name} " + f"SELECT {read_attrs_query} FROM {self.nodes_table_name} " + (self.__roi_query(roi) if roi is not None else "") + ( f" {'WHERE' if roi is None else 'AND'} " @@ -242,27 +328,26 @@ def read_nodes( ) ) - read_attrs = ( - ["id"] - + self.position_attributes - + (list(self.node_attrs.keys()) if read_attrs is None else read_attrs) - ) attr_filter = attr_filter if attr_filter is not None else {} for k, v in attr_filter.items(): select_statement += f" AND {k}={self.__convert_to_sql(v)}" nodes = [ - { - key: val - for key, val in zip( - ["id"] + self.position_attributes + list(self.node_attrs.keys()), - values, - ) - if key in read_attrs and val is not None - } + self._columns_to_node_attrs( + { + key: val + for key, val in zip(read_columns, values) + }, + read_attrs + ) for values in self._select_query(select_statement) ] + for values in self._select_query(select_statement): + print(values) + print(self.node_attrs.keys()) + print(nodes) + return nodes def num_nodes(self, roi: Roi) -> int: @@ -348,8 +433,8 @@ def write_edges( if roi is None: roi = Roi( - (None,) * len(self.position_attributes), - (None,) * len(self.position_attributes), + (None,) * self.ndims, + (None,) * self.ndims, ) values = [] @@ -438,7 +523,7 @@ def write_nodes( logger.debug("Writing nodes in %s", roi) attrs = attributes if attributes is not None else list(self.node_attrs.keys()) - columns = ("id",) + tuple(self.position_attributes) + tuple(attrs) + columns = ("id",) + tuple(attrs) values = [] for node_id, data in nodes.items(): @@ -446,11 +531,9 @@ def write_nodes( pos = 
self.__get_node_pos(data) if roi is not None and not roi.contains(pos): continue - for i, position_attribute in enumerate(self.position_attributes): - data[position_attribute] = pos[i] values.append( [node_id] - + [data.get(attr, None) for attr in self.position_attributes + attrs] + + [data.get(attr, None) for attr in attrs] ) if len(values) == 0: @@ -498,83 +581,77 @@ def update_nodes( self._commit() - def __init_metadata(self): - metadata = self._read_metadata() - - if metadata: - self.__check_metadata(metadata) - else: - metadata = self.__create_metadata() - self._store_metadata(metadata) - def __create_metadata(self): """Sets the metadata in the meta collection to the provided values""" - if not self.directed: - # default is False - self.directed = self.directed if self.directed is not None else False - if not self.total_roi: - # default is an unbounded roi - self.total_roi = Roi( - (None,) * len(self.position_attributes), - (None,) * len(self.position_attributes), - ) - metadata = { + "position_attribute": self.position_attribute, "directed": self.directed, "total_roi_offset": self.total_roi.offset, "total_roi_shape": self.total_roi.shape, + "nodes_table_name": self.nodes_table_name, + "edges_table_name": self.edges_table_name, + "endpoint_names": self.endpoint_names, "node_attrs": {k: type_to_str(v) for k, v in self.node_attrs.items()}, "edge_attrs": {k: type_to_str(v) for k, v in self.edge_attrs.items()}, + "ndims": self.ndims, } return metadata - def __check_metadata(self, metadata): - """Checks if the provided metadata matches the existing - metadata in the meta collection""" - - if self.directed is not None and metadata["directed"] != self.directed: - raise ValueError( - ( - "Input parameter directed={} does not match" - "directed value {} already in stored metadata" - ).format(self.directed, metadata["directed"]) - ) - elif self.directed is None: - self.directed = metadata["directed"] - if self.total_roi is not None: - if self.total_roi.get_offset() != 
metadata["total_roi_offset"]: - raise ValueError( - ( - "Input total_roi offset {} does not match" - "total_roi offset {} already stored in metadata" - ).format(self.total_roi.get_offset(), metadata["total_roi_offset"]) - ) - if self.total_roi.get_shape() != metadata["total_roi_shape"]: - raise ValueError( - ( - "Input total_roi shape {} does not match" - "total_roi shape {} already stored in metadata" - ).format(self.total_roi.get_shape(), metadata["total_roi_shape"]) + def __load_metadata(self, metadata): + """Load the provided metadata into this object's attributes, check if + it is consistent with already populated fields.""" + + # simple attributes + for attr_name in [ + "position_attribute", + "directed", + "nodes_table_name", + "edges_table_name", + "endpoint_names", + "ndims"]: + + if getattr(self, attr_name) is None: + setattr(self, attr_name, metadata[attr_name]) + else: + value = getattr(self, attr_name) + assert value == metadata[attr_name], ( + f"Attribute {attr_name} is already set to {value} for this " + "object, but disagrees with the stored metadata value of " + f"{metadata[attr_name]}" ) + + # special attributes + + total_roi = Roi(metadata["total_roi_offset"], metadata["total_roi_shape"]) + if self.total_roi is None: + self.total_roi = total_roi else: - self.total_roi = Roi( - metadata["total_roi_offset"], metadata["total_roi_shape"] - ) - metadata["node_attrs"] = {k: eval(v) for k, v in metadata["node_attrs"].items()} - metadata["edge_attrs"] = {k: eval(v) for k, v in metadata["edge_attrs"].items()} - if self._node_attrs is not None: - assert self.node_attrs == metadata["node_attrs"], ( - self.node_attrs, - metadata["node_attrs"], + assert self.total_roi == total_roi, ( + f"Attribute total_roi is already set to {self.total_roi} for " + "this object, but disagrees with the stored metadata value of " + f"{total_roi}" ) + + node_attrs = {k: eval(v) for k, v in metadata["node_attrs"].items()} + edge_attrs = {k: eval(v) for k, v in 
metadata["edge_attrs"].items()} + if self._node_attrs is None: + self.node_attrs = node_attrs else: - self.node_attrs = metadata["node_attrs"] - if self._edge_attrs is not None: - assert self.edge_attrs == metadata["edge_attrs"] + assert self.node_attrs == node_attrs, ( + f"Attribute node_attrs is already set to {self.node_attrs} for " + "this object, but disagrees with the stored metadata value of " + f"{node_attrs}" + ) + if self._edge_attrs is None: + self.edge_attrs = edge_attrs else: - self.edge_attrs = metadata["edge_attrs"] + assert self.edge_attrs == edge_attrs, ( + f"Attribute edge_attrs is already set to {self.edge_attrs} for " + "this object, but disagrees with the stored metadata value of " + f"{edge_attrs}" + ) def __remove_keys(self, dictionary, keys): """Removes given keys from dictionary.""" @@ -582,9 +659,7 @@ def __remove_keys(self, dictionary, keys): return {k: v for k, v in dictionary.items() if k not in keys} def __get_node_pos(self, n: dict[str, Any]) -> Coordinate: - return Coordinate( - (n.get(pos_attr, None) for pos_attr in self.position_attributes) - ) + return Coordinate(n[self.position_attribute]) def __convert_to_sql(self, x: Any) -> str: if isinstance(x, str): @@ -605,15 +680,16 @@ def __attr_query(self, attrs: dict[str, Any]) -> str: def __roi_query(self, roi: Roi) -> str: query = "WHERE " - for dim, pos_attr in enumerate(self.position_attributes): + pos_attr = self.position_attribute + for dim in range(self.ndims): if dim > 0: query += " AND " if roi.begin[dim] is not None and roi.end[dim] is not None: - query += f"{pos_attr} BETWEEN {roi.begin[dim]} and {roi.end[dim]}" + query += f"{pos_attr}[{dim + 1}] BETWEEN {roi.begin[dim]} and {roi.end[dim]}" elif roi.begin[dim] is not None: - query += f"{pos_attr}>={roi.begin[dim]}" + query += f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" elif roi.begin[dim] is not None: - query += f"{pos_attr}<{roi.end[dim]}" + query += f"{pos_attr}[{dim + 1}]<{roi.end[dim]}" else: query = query[:-5] return 
query diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/funlib/persistence/graphs/sqlite_graph_database.py index 5d45b4d..4c3760e 100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/funlib/persistence/graphs/sqlite_graph_database.py @@ -1,10 +1,12 @@ from .sql_graph_database import SQLGraphDataBase, AttributeType +from .types import Array from funlib.geometry import Roi import logging import sqlite3 import json +import re from pathlib import Path from typing import Optional, Any @@ -15,7 +17,7 @@ class SQLiteGraphDataBase(SQLGraphDataBase): def __init__( self, db_file: Path, - position_attributes: list[str], + position_attribute: str, mode: str = "r+", directed: Optional[bool] = None, total_roi: Optional[Roi] = None, @@ -31,8 +33,8 @@ def __init__( self.cur = self.con.cursor() super().__init__( - position_attributes, mode=mode, + position_attribute=position_attribute, directed=directed, total_roi=total_roi, nodes_table=nodes_table, @@ -42,6 +44,22 @@ def __init__( edge_attrs=edge_attrs, ) + # in SQLite, array types are stored in individual columns + self.node_array_columns = { + attr: [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + for attr, attr_type in self.node_attrs.items() + if isinstance(attr_type, Array) + } + self.edge_array_columns = { + attr: [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + for attr, attr_type in self.edge_attrs.items() + if isinstance(attr_type, Array) + } + def _drop_tables(self) -> None: logger.info( "dropping collections %s, %s", @@ -57,11 +75,13 @@ def _drop_tables(self) -> None: self.meta_collection.unlink() def _create_tables(self) -> None: - position_template = "{pos_attr} REAL not null" - columns = [ - position_template.format(pos_attr=pos_attr) - for pos_attr in self.position_attributes - ] + list(self.node_attrs.keys()) + columns = list(self.node_attrs.keys()) + columns.remove(self.position_attribute) + position_columns = [ + self.position_attribute + f"_{d}" + for d in 
range(self.ndims) + ] + columns += [ p + " REAL NOT NULL" for p in position_columns ] self.cur.execute( f"CREATE TABLE IF NOT EXISTS " f"{self.nodes_table_name}(" @@ -70,7 +90,7 @@ def _create_tables(self) -> None: ")" ) self.cur.execute( - f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(self.position_attributes)})" + f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(position_columns)})" ) edge_columns = [ f"{self.endpoint_names[0]} INTEGER not null", @@ -95,12 +115,41 @@ def _read_metadata(self) -> Optional[dict[str, Any]]: return json.load(f) def _select_query(self, query): + + # replace array_name[1] with array_name_0 + # ^^^ + # Yes, that's not a typo + # + # If SQL dialects allow array element access, they start counting at 1. + # We don't want that, we start counting at 0 like normal people. + query = re.sub(r'\[(\d+)\]', lambda m: "_" + str(int(m.group(1)) - 1), query) + try: return self.cur.execute(query) except sqlite3.OperationalError as e: raise ValueError(query) from e def _insert_query(self, table, columns, values, fail_if_exists=False, commit=True): + + # explode array attributes into multiple columns + + exploded_values = [] + for row in values: + exploded_columns = [] + exploded_row_values = [] + for column, value in zip(columns, row): + if column in self.node_array_columns: + for c, v in zip(self.node_array_columns[column], value): + exploded_columns.append(c) + exploded_row_values.append(v) + else: + exploded_columns.append(column) + exploded_row_values.append(value) + exploded_values.append(exploded_row_values) + + columns = exploded_columns + values = exploded_values + insert_statement = ( f"INSERT{' OR IGNORE' if not fail_if_exists else ''} INTO {table} " f"({', '.join(columns)}) VALUES ({', '.join(['?'] * len(columns))})" @@ -121,3 +170,53 @@ def _update_query(self, query, commit=True): def _commit(self): self.con.commit() + + def _node_attrs_to_columns(self, attrs): + columns = [] + for 
attr in attrs: + attr_type = self.node_attrs[attr] + if isinstance(attr_type, Array): + columns += [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + else: + columns.append(attr) + return columns + + def _columns_to_node_attrs(self, columns, query_attrs): + attrs = {} + for attr in query_attrs: + if attr in self.node_array_columns: + value = tuple( + columns[f"{attr}_{d}"] + for d in range(self.node_attrs[attr].size) + ) + else: + value = columns[attr] + attrs[attr] = value + return attrs + + def _edge_attrs_to_columns(self, attrs): + columns = [] + for attr in attrs: + attr_type = self.edge_attrs[attr] + if isinstance(attr_type, Array): + columns += [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + else: + columns.append(attr) + return columns + + def _columns_to_edge_attrs(self, columns, query_attrs): + attrs = {} + for attr in query_attrs: + if attr in self.edge_array_columns: + value = tuple( + columns[f"{attr}_{d}"] + for d in range(self.edge_attrs[attr].size) + ) + else: + value = columns[attr] + attrs[attr] = value + return attrs diff --git a/tests/conftest.py b/tests/conftest.py index 995ce49..7e220e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ def sqlite_provider_factory( ): return SQLiteGraphDataBase( tmpdir / "test_sqlite_graph.db", - position_attributes=["z", "y", "x"], + position_attribute="position", mode=mode, directed=directed, total_roi=total_roi, @@ -32,7 +32,7 @@ def psql_provider_factory( mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None ): return PgSQLGraphDatabase( - position_attributes=["z", "y", "x"], + position_attribute="position", db_name="pytest", mode=mode, directed=directed, diff --git a/tests/test_graph.py b/tests/test_graph.py index 7c24800..1ee986b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,4 +1,5 @@ from funlib.geometry import Roi, Coordinate +from funlib.persistence.graphs import Array import networkx as nx import pytest @@ -6,15 +7,15 @@ def 
test_graph_filtering(provider_factory): graph_writer = provider_factory( - "w", node_attrs={"selected": bool}, edge_attrs={"selected": bool} + "w", node_attrs={"position": Array(float, 3), "selected": bool}, edge_attrs={"selected": bool} ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_writer[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True) - graph.add_node(42, z=1, y=1, x=1, selected=False) - graph.add_node(23, z=5, y=5, x=5, selected=True) - graph.add_node(57, z=7, y=7, x=7, selected=True) + graph.add_node(2, position=(2, 2, 2), selected=True) + graph.add_node(42, position=(1, 1, 1), selected=False) + graph.add_node(23, position=(5, 5, 5), selected=True) + graph.add_node(57, position=(7, 7, 7), selected=True) graph.add_edge(42, 23, selected=False) graph.add_edge(57, 23, selected=True) graph.add_edge(2, 42, selected=True) @@ -39,7 +40,7 @@ def test_graph_filtering(provider_factory): roi, nodes_filter={"selected": True}, edges_filter={"selected": True} ) nodes_with_position = [ - node for node, data in filtered_subgraph.nodes(data=True) if "z" in data + node for node, data in filtered_subgraph.nodes(data=True) if "position" in data ] assert expected_node_ids == nodes_with_position assert len(filtered_subgraph.edges()) == len(expected_edge_endpoints) @@ -53,16 +54,16 @@ def test_graph_filtering(provider_factory): def test_graph_filtering_complex(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"selected": bool, "test": str}, + node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True, test="test") - graph.add_node(42, z=1, y=1, x=1, selected=False, test="test2") - graph.add_node(23, z=5, y=5, x=5, selected=True, test="test2") - graph.add_node(57, z=7, y=7, x=7, selected=True, test="test") + graph.add_node(2, position=(2, 2, 2), selected=True, test="test") + 
graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") + graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") + graph.add_node(57, position=(7, 7, 7), selected=True, test="test") graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) @@ -94,7 +95,7 @@ def test_graph_filtering_complex(provider_factory): edges_filter={"selected": True, "a": 100}, ) nodes_with_position = [ - node for node, data in filtered_subgraph.nodes(data=True) if "z" in data + node for node, data in filtered_subgraph.nodes(data=True) if "position" in data ] assert expected_node_ids == nodes_with_position assert len(filtered_subgraph.edges()) == 0 @@ -103,16 +104,16 @@ def test_graph_filtering_complex(provider_factory): def test_graph_read_and_update_specific_attrs(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"selected": bool, "test": str}, + node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int, "c": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True, test="test") - graph.add_node(42, z=1, y=1, x=1, selected=False, test="test2") - graph.add_node(23, z=5, y=5, x=5, selected=True, test="test2") - graph.add_node(57, z=7, y=7, x=7, selected=True, test="test") + graph.add_node(2, position=(2, 2, 2), selected=True, test="test") + graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") + graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") + graph.add_node(57, position=(7, 7, 7), selected=True, test="test") graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) @@ -154,7 +155,7 @@ def test_graph_read_and_update_specific_attrs(provider_factory): def test_graph_read_unbounded_roi(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"selected": bool, "test": str}, + 
node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) @@ -162,10 +163,10 @@ def test_graph_read_unbounded_roi(provider_factory): graph = graph_provider[roi] - graph.add_node(2, z=2, y=2, x=2, selected=True, test="test") - graph.add_node(42, z=1, y=1, x=1, selected=False, test="test2") - graph.add_node(23, z=5, y=5, x=5, selected=True, test="test2") - graph.add_node(57, z=7, y=7, x=7, selected=True, test="test") + graph.add_node(2, position=(2, 2, 2), selected=True, test="test") + graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") + graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") + graph.add_node(57, position=(7, 7, 7), selected=True, test="test") graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) @@ -196,14 +197,14 @@ def test_graph_read_unbounded_roi(provider_factory): def test_graph_read_meta_values(provider_factory): roi = Roi((0, 0, 0), (10, 10, 10)) - provider_factory("w", True, roi) + provider_factory("w", True, roi, node_attrs={"position": Array(float, 3)}) graph_provider = provider_factory("r", None, None) assert True == graph_provider.directed assert roi == graph_provider.total_roi def test_graph_default_meta_values(provider_factory): - provider = provider_factory("w", False, None) + provider = provider_factory("w", False, None, node_attrs={"position": Array(float, 3)}) assert False == provider.directed assert provider.total_roi is None or provider.total_roi == Roi( (None, None, None), (None, None, None) @@ -215,26 +216,22 @@ def test_graph_default_meta_values(provider_factory): ) -def test_graph_nonmatching_meta_values(provider_factory): - roi = Roi((0, 0, 0), (10, 10, 10)) - roi2 = Roi((1, 0, 0), (10, 10, 10)) - provider_factory("w", True, None) - with pytest.raises(ValueError): - provider_factory("r", False, None) - provider_factory("w", None, roi) - with 
pytest.raises(ValueError): - provider_factory("r", None, roi2) - - def test_graph_io(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -263,13 +260,20 @@ def test_graph_io(provider_factory): def test_graph_fail_if_exists(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -282,13 +286,20 @@ def test_graph_fail_if_exists(provider_factory): def test_graph_fail_if_not_exists(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, 
zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -302,19 +313,26 @@ def test_graph_fail_if_not_exists(provider_factory): def test_graph_write_attributes(provider_factory): - graph_provider = provider_factory("w", node_attrs={"swip": str}) + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(int, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=[0, 0, 0]) + graph.add_node(42, position=[1, 1, 1]) + graph.add_node(23, position=[5, 5, 5], swip="swap") + graph.add_node(57, position=[7, 7, 7], zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) graph_provider.write_graph( - graph, write_nodes=True, write_edges=False, node_attrs=["swip"] + graph, write_nodes=True, write_edges=False, node_attrs=["position", "swip"] ) graph_provider.write_edges( @@ -337,17 +355,34 @@ def test_graph_write_attributes(provider_factory): compare_nodes = [ (node_id, data) for node_id, data in compare_nodes if len(data) > 0 ] - assert nodes == compare_nodes + for n, c in zip(nodes, compare_nodes): + assert n[0] == c[0] + for key in n[1]: + assert key in c[1] + v1 = n[1][key] + v2 = c[1][key] + try: + for e1, e2 in zip(v1, v2): + assert e1 == e2 + except: + assert v1 == v2 def test_graph_write_roi(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, 
y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) @@ -373,13 +408,20 @@ def test_graph_write_roi(provider_factory): def test_graph_connected_components(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(57, 23) graph.add_edge(2, 42) try: @@ -404,15 +446,22 @@ def test_graph_connected_components(provider_factory): def test_graph_has_edge(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory( + "w", + node_attrs={ + "position": Array(float, 3), + "swip": str, + "zap": str, + } + ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] - graph.add_node(2, z=0, y=0, x=0) - graph.add_node(42, z=1, y=1, x=1) - graph.add_node(23, z=5, y=5, x=5, swip="swap") - graph.add_node(57, z=7, y=7, x=7, zap="zip") + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) From 9a2a0e6814583e8b3b3525456c85e38bc6397f8a Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Fri, 8 Mar 2024 10:04:55 -0500 Subject: [PATCH 
03/11] Move Array to funlib.persistence.types --- funlib/persistence/graphs/__init__.py | 1 - funlib/persistence/graphs/graph_database.py | 2 +- funlib/persistence/graphs/pgsql_graph_database.py | 2 +- funlib/persistence/graphs/sql_graph_database.py | 2 +- funlib/persistence/graphs/sqlite_graph_database.py | 2 +- funlib/persistence/{graphs => }/types.py | 0 tests/test_graph.py | 2 +- 7 files changed, 5 insertions(+), 6 deletions(-) rename funlib/persistence/{graphs => }/types.py (100%) diff --git a/funlib/persistence/graphs/__init__.py b/funlib/persistence/graphs/__init__.py index f4627c1..6b767a2 100644 --- a/funlib/persistence/graphs/__init__.py +++ b/funlib/persistence/graphs/__init__.py @@ -1,3 +1,2 @@ from .sqlite_graph_database import SQLiteGraphDataBase # noqa from .pgsql_graph_database import PgSQLGraphDatabase # noqa -from .types import Array # noqa diff --git a/funlib/persistence/graphs/graph_database.py b/funlib/persistence/graphs/graph_database.py index b452f4f..94224dc 100644 --- a/funlib/persistence/graphs/graph_database.py +++ b/funlib/persistence/graphs/graph_database.py @@ -1,6 +1,6 @@ from networkx import Graph from funlib.geometry import Roi -from .types import Array +from ..types import Array import logging from abc import ABC, abstractmethod diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/funlib/persistence/graphs/pgsql_graph_database.py index 6c701c5..b29c064 100644 --- a/funlib/persistence/graphs/pgsql_graph_database.py +++ b/funlib/persistence/graphs/pgsql_graph_database.py @@ -1,5 +1,5 @@ from .sql_graph_database import SQLGraphDataBase -from .types import Array +from ..types import Array from funlib.geometry import Roi import logging diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index f3af09b..f6801a3 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -1,5 +1,5 @@ from .graph_database import 
GraphDataBase, AttributeType -from .types import Array, type_to_str +from ..types import Array, type_to_str from funlib.geometry import Coordinate from funlib.geometry import Roi diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/funlib/persistence/graphs/sqlite_graph_database.py index 4c3760e..248d17b 100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/funlib/persistence/graphs/sqlite_graph_database.py @@ -1,5 +1,5 @@ from .sql_graph_database import SQLGraphDataBase, AttributeType -from .types import Array +from ..types import Array from funlib.geometry import Roi diff --git a/funlib/persistence/graphs/types.py b/funlib/persistence/types.py similarity index 100% rename from funlib/persistence/graphs/types.py rename to funlib/persistence/types.py diff --git a/tests/test_graph.py b/tests/test_graph.py index 1ee986b..3855080 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,5 +1,5 @@ from funlib.geometry import Roi, Coordinate -from funlib.persistence.graphs import Array +from funlib.persistence.types import Array import networkx as nx import pytest From 28ccf8177b2917a76803f0fd6f4f2776592713b1 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Fri, 8 Mar 2024 10:06:57 -0500 Subject: [PATCH 04/11] Rename types.{Array -> Vec} to avoid name clash with funlib.persistence.Array --- funlib/persistence/graphs/graph_database.py | 4 +-- .../graphs/pgsql_graph_database.py | 4 +-- .../persistence/graphs/sql_graph_database.py | 8 +++--- .../graphs/sqlite_graph_database.py | 10 +++---- funlib/persistence/types.py | 6 ++-- tests/test_graph.py | 28 +++++++++---------- 6 files changed, 30 insertions(+), 30 deletions(-) diff --git a/funlib/persistence/graphs/graph_database.py b/funlib/persistence/graphs/graph_database.py index 94224dc..40f50a0 100644 --- a/funlib/persistence/graphs/graph_database.py +++ b/funlib/persistence/graphs/graph_database.py @@ -1,6 +1,6 @@ from networkx import Graph from funlib.geometry import Roi -from ..types 
import Array +from ..types import Vec import logging from abc import ABC, abstractmethod @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -AttributeType = type | str | Array +AttributeType = type | str | Vec class GraphDataBase(ABC): diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/funlib/persistence/graphs/pgsql_graph_database.py index b29c064..7e4f3ed 100644 --- a/funlib/persistence/graphs/pgsql_graph_database.py +++ b/funlib/persistence/graphs/pgsql_graph_database.py @@ -1,5 +1,5 @@ from .sql_graph_database import SQLGraphDataBase -from ..types import Array +from ..types import Vec from funlib.geometry import Roi import logging @@ -186,7 +186,7 @@ def __sql_value(self, value): return str(value) def __sql_type(self, type): - if isinstance(type, Array): + if isinstance(type, Vec): return self.__sql_type(type.dtype) + f"[{type.size}]" try: return {bool: "BOOLEAN", int: "INTEGER", str: "VARCHAR", float: "REAL"}[ diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index f6801a3..6e5f7f8 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -1,5 +1,5 @@ from .graph_database import GraphDataBase, AttributeType -from ..types import Array, type_to_str +from ..types import Vec, type_to_str from funlib.geometry import Coordinate from funlib.geometry import Roi @@ -120,12 +120,12 @@ def get(value, default): ) position_type = node_attrs[self.position_attribute] - if isinstance(position_type, Array): + if isinstance(position_type, Vec): self.ndims = position_type.size assert self.ndims > 1, ( - "Don't use Arrays of size 1 for the position, use the " + "Don't use Vecs of size 1 for the position, use the " "scalar type directly instead (i.e., 'float' instead of " - "'Array(float, 1)'." + "'Vec(float, 1)'." 
) # if ndims == 1, we know that we have a single scalar now else: diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/funlib/persistence/graphs/sqlite_graph_database.py index 248d17b..af1ce76 100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/funlib/persistence/graphs/sqlite_graph_database.py @@ -1,5 +1,5 @@ from .sql_graph_database import SQLGraphDataBase, AttributeType -from ..types import Array +from ..types import Vec from funlib.geometry import Roi @@ -50,14 +50,14 @@ def __init__( f"{attr}_{d}" for d in range(attr_type.size) ] for attr, attr_type in self.node_attrs.items() - if isinstance(attr_type, Array) + if isinstance(attr_type, Vec) } self.edge_array_columns = { attr: [ f"{attr}_{d}" for d in range(attr_type.size) ] for attr, attr_type in self.edge_attrs.items() - if isinstance(attr_type, Array) + if isinstance(attr_type, Vec) } def _drop_tables(self) -> None: @@ -175,7 +175,7 @@ def _node_attrs_to_columns(self, attrs): columns = [] for attr in attrs: attr_type = self.node_attrs[attr] - if isinstance(attr_type, Array): + if isinstance(attr_type, Vec): columns += [ f"{attr}_{d}" for d in range(attr_type.size) ] @@ -200,7 +200,7 @@ def _edge_attrs_to_columns(self, attrs): columns = [] for attr in attrs: attr_type = self.edge_attrs[attr] - if isinstance(attr_type, Array): + if isinstance(attr_type, Vec): columns += [ f"{attr}_{d}" for d in range(attr_type.size) ] diff --git a/funlib/persistence/types.py b/funlib/persistence/types.py index c49f2d6..938bb9b 100644 --- a/funlib/persistence/types.py +++ b/funlib/persistence/types.py @@ -2,13 +2,13 @@ @dataclass -class Array: +class Vec: dtype: type | str size: int def type_to_str(type): - if isinstance(type, Array): - return f"Array({type_to_str(type.dtype)}, {type.size})" + if isinstance(type, Vec): + return f"Vec({type_to_str(type.dtype)}, {type.size})" else: return type.__name__ diff --git a/tests/test_graph.py b/tests/test_graph.py index 3855080..f738221 100644 --- 
a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,5 +1,5 @@ from funlib.geometry import Roi, Coordinate -from funlib.persistence.types import Array +from funlib.persistence.types import Vec import networkx as nx import pytest @@ -7,7 +7,7 @@ def test_graph_filtering(provider_factory): graph_writer = provider_factory( - "w", node_attrs={"position": Array(float, 3), "selected": bool}, edge_attrs={"selected": bool} + "w", node_attrs={"position": Vec(float, 3), "selected": bool}, edge_attrs={"selected": bool} ) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_writer[roi] @@ -54,7 +54,7 @@ def test_graph_filtering(provider_factory): def test_graph_filtering_complex(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, + node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) @@ -104,7 +104,7 @@ def test_graph_filtering_complex(provider_factory): def test_graph_read_and_update_specific_attrs(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, + node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int, "c": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) @@ -155,7 +155,7 @@ def test_graph_read_and_update_specific_attrs(provider_factory): def test_graph_read_unbounded_roi(provider_factory): graph_provider = provider_factory( "w", - node_attrs={"position": Array(float, 3), "selected": bool, "test": str}, + node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) roi = Roi((0, 0, 0), (10, 10, 10)) @@ -197,14 +197,14 @@ def test_graph_read_unbounded_roi(provider_factory): def test_graph_read_meta_values(provider_factory): roi = Roi((0, 0, 0), (10, 10, 10)) - provider_factory("w", True, 
roi, node_attrs={"position": Array(float, 3)}) + provider_factory("w", True, roi, node_attrs={"position": Vec(float, 3)}) graph_provider = provider_factory("r", None, None) assert True == graph_provider.directed assert roi == graph_provider.total_roi def test_graph_default_meta_values(provider_factory): - provider = provider_factory("w", False, None, node_attrs={"position": Array(float, 3)}) + provider = provider_factory("w", False, None, node_attrs={"position": Vec(float, 3)}) assert False == provider.directed assert provider.total_roi is None or provider.total_roi == Roi( (None, None, None), (None, None, None) @@ -220,7 +220,7 @@ def test_graph_io(provider_factory): graph_provider = provider_factory( "w", node_attrs={ - "position": Array(float, 3), + "position": Vec(float, 3), "swip": str, "zap": str, } @@ -263,7 +263,7 @@ def test_graph_fail_if_exists(provider_factory): graph_provider = provider_factory( "w", node_attrs={ - "position": Array(float, 3), + "position": Vec(float, 3), "swip": str, "zap": str, } @@ -289,7 +289,7 @@ def test_graph_fail_if_not_exists(provider_factory): graph_provider = provider_factory( "w", node_attrs={ - "position": Array(float, 3), + "position": Vec(float, 3), "swip": str, "zap": str, } @@ -316,7 +316,7 @@ def test_graph_write_attributes(provider_factory): graph_provider = provider_factory( "w", node_attrs={ - "position": Array(int, 3), + "position": Vec(int, 3), "swip": str, "zap": str, } @@ -372,7 +372,7 @@ def test_graph_write_roi(provider_factory): graph_provider = provider_factory( "w", node_attrs={ - "position": Array(float, 3), + "position": Vec(float, 3), "swip": str, "zap": str, } @@ -411,7 +411,7 @@ def test_graph_connected_components(provider_factory): graph_provider = provider_factory( "w", node_attrs={ - "position": Array(float, 3), + "position": Vec(float, 3), "swip": str, "zap": str, } @@ -449,7 +449,7 @@ def test_graph_has_edge(provider_factory): graph_provider = provider_factory( "w", node_attrs={ - "position": 
Array(float, 3), + "position": Vec(float, 3), "swip": str, "zap": str, } From 2896331d86918499be5bb4e68ef980b891442168 Mon Sep 17 00:00:00 2001 From: pattonw Date: Fri, 8 Mar 2024 07:08:43 -0800 Subject: [PATCH 05/11] improve error message on sqlite insert fail --- funlib/persistence/graphs/sqlite_graph_database.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/funlib/persistence/graphs/sqlite_graph_database.py index af1ce76..bed2533 100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/funlib/persistence/graphs/sqlite_graph_database.py @@ -154,7 +154,12 @@ def _insert_query(self, table, columns, values, fail_if_exists=False, commit=Tru f"INSERT{' OR IGNORE' if not fail_if_exists else ''} INTO {table} " f"({', '.join(columns)}) VALUES ({', '.join(['?'] * len(columns))})" ) - self.cur.executemany(insert_statement, values) + try: + self.cur.executemany(insert_statement, values) + except sqlite3.IntegrityError as e: + raise ValueError( + f"Failed to insert values {values} with types {[[type(x) for x in row] for row in values]} into table {table} with columns {columns}" + ) from e if commit: self.con.commit() From b639b42e4593e47bdfeec485b950ed229c164bf0 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Fri, 8 Mar 2024 11:55:58 -0500 Subject: [PATCH 06/11] Throw more informative exception if graph DB metadata doesn't exist --- funlib/persistence/graphs/sql_graph_database.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index 6e5f7f8..8e35fd3 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -97,6 +97,8 @@ def __init__( self.ndims = None # to be read from metadata metadata = self._read_metadata() + if metadata is None: + RuntimeError("metadata does not exist, can't open in read mode") 
self.__load_metadata(metadata) if mode in self.create_modes: From 33b8227f3128ccdd2fe45edcc8a440a7a0b5c39a Mon Sep 17 00:00:00 2001 From: pattonw Date: Fri, 8 Mar 2024 08:59:40 -0800 Subject: [PATCH 07/11] raise Exception --- funlib/persistence/graphs/sql_graph_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index 8e35fd3..19c07fe 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -98,7 +98,7 @@ def __init__( metadata = self._read_metadata() if metadata is None: - RuntimeError("metadata does not exist, can't open in read mode") + raise RuntimeError("metadata does not exist, can't open in read mode") self.__load_metadata(metadata) if mode in self.create_modes: From 19cedacee93ea3730f1aace69bfd3cef5cd159ad Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Fri, 8 Mar 2024 14:11:52 -0500 Subject: [PATCH 08/11] Support Vec for non-position attributes in SQLiteGraphDB --- .../graphs/sqlite_graph_database.py | 70 ++++++++++++------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/funlib/persistence/graphs/sqlite_graph_database.py index bed2533..3c134c5 100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/funlib/persistence/graphs/sqlite_graph_database.py @@ -32,6 +32,9 @@ def __init__( self.con = sqlite3.connect(db_file) self.cur = self.con.cursor() + self._node_array_columns = None + self._edge_array_columns = None + super().__init__( mode=mode, position_attribute=position_attribute, @@ -44,21 +47,29 @@ def __init__( edge_attrs=edge_attrs, ) - # in SQLite, array types are stored in individual columns - self.node_array_columns = { - attr: [ - f"{attr}_{d}" for d in range(attr_type.size) - ] - for attr, attr_type in self.node_attrs.items() - if isinstance(attr_type, Vec) - } - self.edge_array_columns 
= { - attr: [ - f"{attr}_{d}" for d in range(attr_type.size) - ] - for attr, attr_type in self.edge_attrs.items() - if isinstance(attr_type, Vec) - } + @property + def node_array_columns(self): + if not self._node_array_columns: + self._node_array_columns = { + attr: [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + for attr, attr_type in self.node_attrs.items() + if isinstance(attr_type, Vec) + } + return self._node_array_columns + + @property + def edge_array_columns(self): + if not self._edge_array_columns: + self._edge_array_columns = { + attr: [ + f"{attr}_{d}" for d in range(attr_type.size) + ] + for attr, attr_type in self.edge_attrs.items() + if isinstance(attr_type, Vec) + } + return self._edge_array_columns def _drop_tables(self) -> None: logger.info( @@ -75,27 +86,36 @@ def _drop_tables(self) -> None: self.meta_collection.unlink() def _create_tables(self) -> None: - columns = list(self.node_attrs.keys()) - columns.remove(self.position_attribute) - position_columns = [ - self.position_attribute + f"_{d}" - for d in range(self.ndims) - ] - columns += [ p + " REAL NOT NULL" for p in position_columns ] + + node_columns = ["id INTEGER not null PRIMARY KEY"] + for attr in self.node_attrs.keys(): + if attr in self.node_array_columns: + node_columns += self.node_array_columns[attr] + else: + node_columns.append(attr) + self.cur.execute( f"CREATE TABLE IF NOT EXISTS " f"{self.nodes_table_name}(" - "id INTEGER not null PRIMARY KEY, " - f"{', '.join(columns)}" + f"{', '.join(node_columns)}" ")" ) + if self.ndims > 1: + position_columns = self.node_array_columns[self.position_attribute] + else: + position_columns = self.position_attribute self.cur.execute( f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(position_columns)})" ) edge_columns = [ f"{self.endpoint_names[0]} INTEGER not null", f"{self.endpoint_names[1]} INTEGER not null", - ] + [f"{edge_attr}" for edge_attr in self.edge_attrs.keys()] + ] + for attr in 
self.edge_attrs.keys(): + if attr in self.edge_array_columns: + edge_columns += self.edge_array_columns[attr] + else: + edge_columns.append(attr) self.cur.execute( f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}(" + f"{', '.join(edge_columns)}" From ea69950b37c82f49b78ed7e58032f33c929b3111 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Fri, 8 Mar 2024 15:43:07 -0500 Subject: [PATCH 09/11] Change type of id column in PgSQLGraphDatabase to BIGINT --- funlib/persistence/graphs/pgsql_graph_database.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/funlib/persistence/graphs/pgsql_graph_database.py index 7e4f3ed..8b8a411 100644 --- a/funlib/persistence/graphs/pgsql_graph_database.py +++ b/funlib/persistence/graphs/pgsql_graph_database.py @@ -92,7 +92,7 @@ def _create_tables(self) -> None: self.__exec( f"CREATE TABLE IF NOT EXISTS " f"{self.nodes_table_name}(" - "id INTEGER not null PRIMARY KEY, " + "id BIGINT not null PRIMARY KEY, " f"{', '.join(column_types)}" ")" ) @@ -106,8 +106,8 @@ def _create_tables(self) -> None: column_types = [f"{c} {t}" for c, t in zip(columns, types)] self.__exec( f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}(" - f"{self.endpoint_names[0]} INTEGER not null, " - f"{self.endpoint_names[1]} INTEGER not null, " + f"{self.endpoint_names[0]} BIGINT not null, " + f"{self.endpoint_names[1]} BIGINT not null, " f"{' '.join([c + ',' for c in column_types])}" f"PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})" ")" From 8db2a237170dfcb076634888f1e64acdc15e375f Mon Sep 17 00:00:00 2001 From: pattonw Date: Wed, 13 Mar 2024 12:31:41 -0700 Subject: [PATCH 10/11] handle None location case and remove print statements --- funlib/persistence/graphs/pgsql_graph_database.py | 1 - funlib/persistence/graphs/sql_graph_database.py | 14 ++++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/funlib/persistence/graphs/pgsql_graph_database.py 
b/funlib/persistence/graphs/pgsql_graph_database.py index 8b8a411..5a2f4b2 100644 --- a/funlib/persistence/graphs/pgsql_graph_database.py +++ b/funlib/persistence/graphs/pgsql_graph_database.py @@ -137,7 +137,6 @@ def _read_metadata(self) -> Optional[dict[str, Any]]: return None def _select_query(self, query) -> Iterable[Any]: - print(query) self.__exec(query) return self.cur diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index 19c07fe..38a73f5 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -345,11 +345,6 @@ def read_nodes( for values in self._select_query(select_statement) ] - for values in self._select_query(select_statement): - print(values) - print(self.node_attrs.keys()) - print(nodes) - return nodes def num_nodes(self, roi: Roi) -> int: @@ -445,7 +440,7 @@ def write_edges( u, v = min(u, v), max(u, v) pos_u = self.__get_node_pos(nodes[u]) - if not roi.contains(pos_u): + if pos_u is None or not roi.contains(pos_u): logger.debug( ( f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," @@ -660,8 +655,11 @@ def __remove_keys(self, dictionary, keys): return {k: v for k, v in dictionary.items() if k not in keys} - def __get_node_pos(self, n: dict[str, Any]) -> Coordinate: - return Coordinate(n[self.position_attribute]) + def __get_node_pos(self, n: dict[str, Any]) -> Optional[Coordinate]: + try: + return Coordinate(n[self.position_attribute]) + except KeyError: + return None def __convert_to_sql(self, x: Any) -> str: if isinstance(x, str): From 4b4e685913b7e588f60a86104d59a8a70149b75f Mon Sep 17 00:00:00 2001 From: William Patton Date: Mon, 18 Mar 2024 15:52:56 -0700 Subject: [PATCH 11/11] tests: remove older versions of python, add new versions --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml 
index 984cc6c..09262d0 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -23,7 +23,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10"] + python-version: ["3.10", "3.11", "3.12"] env: PGUSER: postgres