Commit

Add PgSQLGraphDatabase
funkey committed Nov 17, 2023
1 parent 685ee32 commit c514cf2
Showing 4 changed files with 217 additions and 20 deletions.
1 change: 1 addition & 0 deletions funlib/persistence/graphs/__init__.py
@@ -1 +1,2 @@
from .sqlite_graph_database import SQLiteGraphDataBase # noqa
from .pgsql_graph_database import PgSQLGraphDatabase # noqa
190 changes: 190 additions & 0 deletions funlib/persistence/graphs/pgsql_graph_database.py
@@ -0,0 +1,190 @@
from .sql_graph_database import SQLGraphDataBase
from funlib.geometry import Roi

import logging
import psycopg2
import json
from typing import Optional, Any, Iterable

logger = logging.getLogger(__name__)


class PgSQLGraphDatabase(SQLGraphDataBase):
def __init__(
self,
position_attributes: list[str],
db_name: str,
db_host: str = "localhost",
db_user: Optional[str] = None,
db_password: Optional[str] = None,
db_port: Optional[int] = None,
mode: str = "r+",
directed: Optional[bool] = None,
total_roi: Optional[Roi] = None,
nodes_table: str = "nodes",
edges_table: str = "edges",
endpoint_names: Optional[tuple[str, str]] = None,
node_attrs: Optional[dict[str, type]] = None,
edge_attrs: Optional[dict[str, type]] = None,
):
self.db_host = db_host
self.db_name = db_name
self.db_user = db_user
self.db_password = db_password
self.db_port = db_port

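        # connect to the default "postgres" maintenance database first, so
        # that the target database can be created below if it does not exist
        # yet; the actual working connection to db_name is opened afterwards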
connection = psycopg2.connect(
host=db_host,
database="postgres",
user=db_user,
password=db_password,
port=db_port,
)
connection.autocommit = True
cur = connection.cursor()
try:
cur.execute(f"CREATE DATABASE {db_name}")
        except psycopg2.errors.DuplicateDatabase:
            # the database already exists, moving on...
            connection.rollback()
self.connection = psycopg2.connect(
host=db_host,
database=db_name,
user=db_user,
password=db_password,
port=db_port,
)
# TODO: remove once tests pass:
# self.connection.autocommit = True
self.cur = self.connection.cursor()

super().__init__(
position_attributes,
mode=mode,
directed=directed,
total_roi=total_roi,
nodes_table=nodes_table,
edges_table=edges_table,
endpoint_names=endpoint_names,
node_attrs=node_attrs,
edge_attrs=edge_attrs,
)

def _drop_tables(self) -> None:
logger.info(
"dropping tables %s, %s",
self.nodes_table_name,
self.edges_table_name,
)

self.__exec(f"DROP TABLE IF EXISTS {self.nodes_table_name}")
self.__exec(f"DROP TABLE IF EXISTS {self.edges_table_name}")
self.__exec("DROP TABLE IF EXISTS metadata")
self._commit()

def _create_tables(self) -> None:
columns = self.position_attributes + list(self.node_attrs.keys())
types = [self.__sql_type(float) + " NOT NULL"] * len(
self.position_attributes
) + list([self.__sql_type(t) for t in self.node_attrs.values()])
column_types = [f"{c} {t}" for c, t in zip(columns, types)]
self.__exec(
f"CREATE TABLE IF NOT EXISTS "
f"{self.nodes_table_name}("
"id INTEGER not null PRIMARY KEY, "
f"{', '.join(column_types)}"
")"
)
self.__exec(
f"CREATE INDEX IF NOT EXISTS pos_index ON "
f"{self.nodes_table_name}({','.join(self.position_attributes)})"
)

columns = list(self.edge_attrs.keys())
types = list([self.__sql_type(t) for t in self.edge_attrs.values()])
column_types = [f"{c} {t}" for c, t in zip(columns, types)]
self.__exec(
f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}("
f"{self.endpoint_names[0]} INTEGER not null, "
f"{self.endpoint_names[1]} INTEGER not null, "
f"{' '.join([c + ',' for c in column_types])}"
f"PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})"
")"
)

self._commit()

def _store_metadata(self, metadata) -> None:
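        # metadata is stored as a single JSON-encoded value in a one-column
        # table; drop and recreate the table so there is only ever one row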
self.__exec("DROP TABLE IF EXISTS metadata")
self.__exec("CREATE TABLE metadata (value VARCHAR)")
self._insert_query(
"metadata", ["value"], [[json.dumps(metadata)]], fail_if_exists=True
)

    def _read_metadata(self) -> Optional[dict[str, Any]]:
try:
self.__exec("SELECT value FROM metadata")
except psycopg2.errors.UndefinedTable:
self.connection.rollback()
return None

metadata = self.cur.fetchone()[0]

return json.loads(metadata)

def _select_query(self, query) -> Iterable[Any]:
self.__exec(query)
return self.cur

def _insert_query(
self, table, columns, values, fail_if_exists=False, commit=True
) -> None:
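        # render all rows into a single multi-row VALUES clause, e.g.
        # "VALUES (1, 'a'), (2, 'b')"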
values_str = (
"VALUES ("
+ "), (".join(
[", ".join([self.__sql_value(v) for v in value]) for value in values]
)
+ ")"
)
        # TODO: failing on already existing entries is the default behavior
        # when the table was created with a UNIQUE or PRIMARY KEY constraint;
        # if fail_if_exists == False, this should perform an update (upsert)
        # instead
insert_statement = f"INSERT INTO {table}({', '.join(columns)}) " + values_str
self.__exec(insert_statement)

if commit:
self.connection.commit()

def _update_query(self, query, commit=True) -> None:
self.__exec(query)

if commit:
self.connection.commit()

def _commit(self) -> None:
self.connection.commit()

def __exec(self, query):
try:
return self.cur.execute(query)
        except Exception:
self.connection.rollback()
raise

def __sql_value(self, value):
if isinstance(value, str):
return f"'{value}'"
elif value is None:
return "NULL"
else:
return str(value)

def __sql_type(self, type):
try:
return {bool: "BOOLEAN", int: "INTEGER", str: "VARCHAR", float: "REAL"}[
type
]
        except KeyError:
raise NotImplementedError(
f"attributes of type {type} are not yet supported"
)
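
A minimal usage sketch for the new provider (illustrative only, not part of this commit): it exercises just the constructor arguments defined above and assumes a local PostgreSQL server reachable with the given credentials; the database name, credentials, and attribute names are hypothetical.

from funlib.persistence.graphs import PgSQLGraphDatabase

# assumes a PostgreSQL server on localhost and a user allowed to create the
# (hypothetical) "example_graph" database; the database itself is created on
# demand in __init__ above, table handling is left to the SQLGraphDataBase
# base class
graph_db = PgSQLGraphDatabase(
    position_attributes=["z", "y", "x"],
    db_name="example_graph",
    db_user="postgres",
    db_password="postgres",
    node_attrs={"score": float},
    edge_attrs={"weight": float},
)
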
27 changes: 9 additions & 18 deletions funlib/persistence/graphs/sql_graph_database.py
@@ -254,7 +254,8 @@ def read_nodes(
{
key: val
for key, val in zip(
["id"] + self.position_attributes + list(self.node_attrs.keys()), values
["id"] + self.position_attributes + list(self.node_attrs.keys()),
values,
)
if key in read_attrs and val is not None
}
@@ -316,7 +317,9 @@ def read_edges(
edges = [
{
key: val
for key, val in zip(self.endpoint_names + list(self.edge_attrs.keys()), values)
for key, val in zip(
self.endpoint_names + list(self.edge_attrs.keys()), values
)
if key in edge_attrs
}
for values in self._select_query(select_statement)
@@ -520,14 +523,8 @@ def __create_metadata(self):
"directed": self.directed,
"total_roi_offset": self.total_roi.offset,
"total_roi_shape": self.total_roi.shape,
"node_attrs": {
k: v.__name__
for k, v in self.node_attrs.items()
},
"edge_attrs": {
k: v.__name__
for k, v in self.edge_attrs.items()
},
"node_attrs": {k: v.__name__ for k, v in self.node_attrs.items()},
"edge_attrs": {k: v.__name__ for k, v in self.edge_attrs.items()},
}

return metadata
@@ -564,14 +561,8 @@ def __check_metadata(self, metadata):
self.total_roi = Roi(
metadata["total_roi_offset"], metadata["total_roi_shape"]
)
metadata["node_attrs"] = {
k: eval(v)
for k, v in metadata["node_attrs"].items()
}
metadata["edge_attrs"] = {
k: eval(v)
for k, v in metadata["edge_attrs"].items()
}
metadata["node_attrs"] = {k: eval(v) for k, v in metadata["node_attrs"].items()}
metadata["edge_attrs"] = {k: eval(v) for k, v in metadata["edge_attrs"].items()}
if self._node_attrs is not None:
assert self.node_attrs == metadata["node_attrs"], (
self.node_attrs,
19 changes: 17 additions & 2 deletions tests/conftest.py
@@ -1,12 +1,12 @@
from funlib.persistence.graphs import SQLiteGraphDataBase
from funlib.persistence.graphs import SQLiteGraphDataBase, PgSQLGraphDatabase

import pytest
import pymongo

from pathlib import Path


@pytest.fixture(params=(pytest.param("sqlite"),))
@pytest.fixture(params=(pytest.param("sqlite"), pytest.param("psql")))
def provider_factory(request, tmpdir):
# provides a factory function to generate graph provider
# can provide either mongodb graph provider or file graph provider
@@ -28,7 +28,22 @@ def sqlite_provider_factory(
edge_attrs=edge_attrs,
)

def psql_provider_factory(
mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None
):
return PgSQLGraphDatabase(
position_attributes=["z", "y", "x"],
db_name="pytest",
mode=mode,
directed=directed,
total_roi=total_roi,
node_attrs=node_attrs,
edge_attrs=edge_attrs,
)

if request.param == "sqlite":
yield sqlite_provider_factory
elif request.param == "psql":
yield psql_provider_factory
else:
raise ValueError()
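
For illustration (not part of the diff), a test consuming this parametrized fixture would look roughly like the sketch below; pytest runs it once per backend, and the factory signature mirrors the nested helpers above (mode first, then the optional keyword arguments). The mode value "w" is assumed here to be the write mode understood by the providers.

def test_graph_provider(provider_factory):
    # executed once for "sqlite" and once for "psql"
    provider = provider_factory("w")
    assert provider is not None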
