From c514cf2ab837079209a8306fc8331a7b121c0b77 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Thu, 16 Nov 2023 20:16:42 -0500 Subject: [PATCH] Add PgSQLGraphDatabase --- funlib/persistence/graphs/__init__.py | 1 + .../graphs/pgsql_graph_database.py | 190 ++++++++++++++++++ .../persistence/graphs/sql_graph_database.py | 27 +-- tests/conftest.py | 19 +- 4 files changed, 217 insertions(+), 20 deletions(-) create mode 100644 funlib/persistence/graphs/pgsql_graph_database.py diff --git a/funlib/persistence/graphs/__init__.py b/funlib/persistence/graphs/__init__.py index 1360daf..6b767a2 100644 --- a/funlib/persistence/graphs/__init__.py +++ b/funlib/persistence/graphs/__init__.py @@ -1 +1,2 @@ from .sqlite_graph_database import SQLiteGraphDataBase # noqa +from .pgsql_graph_database import PgSQLGraphDatabase # noqa diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/funlib/persistence/graphs/pgsql_graph_database.py new file mode 100644 index 0000000..467ac27 --- /dev/null +++ b/funlib/persistence/graphs/pgsql_graph_database.py @@ -0,0 +1,190 @@ +from .sql_graph_database import SQLGraphDataBase +from funlib.geometry import Roi + +import logging +import psycopg2 +import json +from typing import Optional, Any, Iterable + +logger = logging.getLogger(__name__) + + +class PgSQLGraphDatabase(SQLGraphDataBase): + def __init__( + self, + position_attributes: list[str], + db_name: str, + db_host: str = "localhost", + db_user: Optional[str] = None, + db_password: Optional[str] = None, + db_port: Optional[int] = None, + mode: str = "r+", + directed: Optional[bool] = None, + total_roi: Optional[Roi] = None, + nodes_table: str = "nodes", + edges_table: str = "edges", + endpoint_names: Optional[tuple[str, str]] = None, + node_attrs: Optional[dict[str, type]] = None, + edge_attrs: Optional[dict[str, type]] = None, + ): + self.db_host = db_host + self.db_name = db_name + self.db_user = db_user + self.db_password = db_password + self.db_port = db_port + + connection = 
psycopg2.connect( + host=db_host, + database="postgres", + user=db_user, + password=db_password, + port=db_port, + ) + connection.autocommit = True + cur = connection.cursor() + try: + cur.execute(f"CREATE DATABASE {db_name}") + except psycopg2.errors.DuplicateDatabase: + # DB already exists, moving on... + connection.rollback() + pass + self.connection = psycopg2.connect( + host=db_host, + database=db_name, + user=db_user, + password=db_password, + port=db_port, + ) + # TODO: remove once tests pass: + # self.connection.autocommit = True + self.cur = self.connection.cursor() + + super().__init__( + position_attributes, + mode=mode, + directed=directed, + total_roi=total_roi, + nodes_table=nodes_table, + edges_table=edges_table, + endpoint_names=endpoint_names, + node_attrs=node_attrs, + edge_attrs=edge_attrs, + ) + + def _drop_tables(self) -> None: + logger.info( + "dropping tables %s, %s", + self.nodes_table_name, + self.edges_table_name, + ) + + self.__exec(f"DROP TABLE IF EXISTS {self.nodes_table_name}") + self.__exec(f"DROP TABLE IF EXISTS {self.edges_table_name}") + self.__exec("DROP TABLE IF EXISTS metadata") + self._commit() + + def _create_tables(self) -> None: + columns = self.position_attributes + list(self.node_attrs.keys()) + types = [self.__sql_type(float) + " NOT NULL"] * len( + self.position_attributes + ) + list([self.__sql_type(t) for t in self.node_attrs.values()]) + column_types = [f"{c} {t}" for c, t in zip(columns, types)] + self.__exec( + f"CREATE TABLE IF NOT EXISTS " + f"{self.nodes_table_name}(" + "id INTEGER not null PRIMARY KEY, " + f"{', '.join(column_types)}" + ")" + ) + self.__exec( + f"CREATE INDEX IF NOT EXISTS pos_index ON " + f"{self.nodes_table_name}({','.join(self.position_attributes)})" + ) + + columns = list(self.edge_attrs.keys()) + types = list([self.__sql_type(t) for t in self.edge_attrs.values()]) + column_types = [f"{c} {t}" for c, t in zip(columns, types)] + self.__exec( + f"CREATE TABLE IF NOT EXISTS 
{self.edges_table_name}(" + f"{self.endpoint_names[0]} INTEGER not null, " + f"{self.endpoint_names[1]} INTEGER not null, " + f"{' '.join([c + ',' for c in column_types])}" + f"PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})" + ")" + ) + + self._commit() + + def _store_metadata(self, metadata) -> None: + self.__exec("DROP TABLE IF EXISTS metadata") + self.__exec("CREATE TABLE metadata (value VARCHAR)") + self._insert_query( + "metadata", ["value"], [[json.dumps(metadata)]], fail_if_exists=True + ) + + def _read_metadata(self) -> Optional[dict[str, Any]]: + try: + self.__exec("SELECT value FROM metadata") + except psycopg2.errors.UndefinedTable: + self.connection.rollback() + return None + + metadata = self.cur.fetchone()[0] + + return json.loads(metadata) + + def _select_query(self, query) -> Iterable[Any]: + self.__exec(query) + return self.cur + + def _insert_query( + self, table, columns, values, fail_if_exists=False, commit=True + ) -> None: + values_str = ( + "VALUES (" + + "), (".join( + [", ".join([self.__sql_value(v) for v in value]) for value in values] + ) + + ")" + ) + # TODO: fail_if_exists is the default if UNIQUE was used to create the + # table, we need to update if fail_if_exists==False + insert_statement = f"INSERT INTO {table}({', '.join(columns)}) " + values_str + self.__exec(insert_statement) + + if commit: + self.connection.commit() + + def _update_query(self, query, commit=True) -> None: + self.__exec(query) + + if commit: + self.connection.commit() + + def _commit(self) -> None: + self.connection.commit() + + def __exec(self, query): + try: + return self.cur.execute(query) + except: + self.connection.rollback() + raise + + def __sql_value(self, value): + if isinstance(value, str): + return f"'{value}'" + elif value is None: + return "NULL" + else: + return str(value) + + def __sql_type(self, type): + try: + return {bool: "BOOLEAN", int: "INTEGER", str: "VARCHAR", float: "REAL"}[ + type + ] + except KeyError: + raise 
NotImplementedError( + f"attributes of type {type} are not yet supported" + ) diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index 398c51e..5b927a7 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -254,7 +254,8 @@ def read_nodes( { key: val for key, val in zip( - ["id"] + self.position_attributes + list(self.node_attrs.keys()), values + ["id"] + self.position_attributes + list(self.node_attrs.keys()), + values, ) if key in read_attrs and val is not None } @@ -316,7 +317,9 @@ def read_edges( edges = [ { key: val - for key, val in zip(self.endpoint_names + list(self.edge_attrs.keys()), values) + for key, val in zip( + self.endpoint_names + list(self.edge_attrs.keys()), values + ) if key in edge_attrs } for values in self._select_query(select_statement) @@ -520,14 +523,8 @@ def __create_metadata(self): "directed": self.directed, "total_roi_offset": self.total_roi.offset, "total_roi_shape": self.total_roi.shape, - "node_attrs": { - k: v.__name__ - for k, v in self.node_attrs.items() - }, - "edge_attrs": { - k: v.__name__ - for k, v in self.edge_attrs.items() - }, + "node_attrs": {k: v.__name__ for k, v in self.node_attrs.items()}, + "edge_attrs": {k: v.__name__ for k, v in self.edge_attrs.items()}, } return metadata @@ -564,14 +561,8 @@ def __check_metadata(self, metadata): self.total_roi = Roi( metadata["total_roi_offset"], metadata["total_roi_shape"] ) - metadata["node_attrs"] = { - k: eval(v) - for k, v in metadata["node_attrs"].items() - } - metadata["edge_attrs"] = { - k: eval(v) - for k, v in metadata["edge_attrs"].items() - } + metadata["node_attrs"] = {k: eval(v) for k, v in metadata["node_attrs"].items()} + metadata["edge_attrs"] = {k: eval(v) for k, v in metadata["edge_attrs"].items()} if self._node_attrs is not None: assert self.node_attrs == metadata["node_attrs"], ( self.node_attrs, diff --git a/tests/conftest.py 
b/tests/conftest.py index 43f3f26..995ce49 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -from funlib.persistence.graphs import SQLiteGraphDataBase +from funlib.persistence.graphs import SQLiteGraphDataBase, PgSQLGraphDatabase import pytest import pymongo @@ -6,7 +6,7 @@ from pathlib import Path -@pytest.fixture(params=(pytest.param("sqlite"),)) +@pytest.fixture(params=(pytest.param("sqlite"), pytest.param("psql"))) def provider_factory(request, tmpdir): # provides a factory function to generate graph provider # can provide either mongodb graph provider or file graph provider @@ -28,7 +28,22 @@ def sqlite_provider_factory( edge_attrs=edge_attrs, ) + def psql_provider_factory( + mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None + ): + return PgSQLGraphDatabase( + position_attributes=["z", "y", "x"], + db_name="pytest", + mode=mode, + directed=directed, + total_roi=total_roi, + node_attrs=node_attrs, + edge_attrs=edge_attrs, + ) + if request.param == "sqlite": yield sqlite_provider_factory + elif request.param == "psql": + yield psql_provider_factory else: raise ValueError()