Skip to content

Commit

Permalink
Merge branch 'funkelab:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar authored Mar 19, 2024
2 parents 1dd1c45 + 4b4e685 commit 8d65b0c
Show file tree
Hide file tree
Showing 8 changed files with 493 additions and 221 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.10", "3.11", "3.12"]

env:
PGUSER: postgres
Expand Down
8 changes: 6 additions & 2 deletions funlib/persistence/graphs/graph_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from networkx import Graph
from funlib.geometry import Roi
from ..types import Vec

import logging
from abc import ABC, abstractmethod
Expand All @@ -9,6 +10,9 @@
logger = logging.getLogger(__name__)


AttributeType = type | str | Vec


class GraphDataBase(ABC):
"""
Interface for graph databases that supports slicing to retrieve
Expand All @@ -33,15 +37,15 @@ def __getitem__(self, roi) -> Graph:

@property
@abstractmethod
def node_attrs(self) -> dict[str, type]:
def node_attrs(self) -> dict[str, AttributeType]:
"""
Return the node attributes supported by the database.
"""
pass

@property
@abstractmethod
def edge_attrs(self) -> dict[str, type]:
def edge_attrs(self) -> dict[str, AttributeType]:
"""
Return the edge attributes supported by the database.
"""
Expand Down
24 changes: 14 additions & 10 deletions funlib/persistence/graphs/pgsql_graph_database.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from .sql_graph_database import SQLGraphDataBase
from ..types import Vec
from funlib.geometry import Roi

import logging
import psycopg2
import json
from typing import Optional, Any, Iterable
from collections.abc import Iterable

logger = logging.getLogger(__name__)


class PgSQLGraphDatabase(SQLGraphDataBase):
def __init__(
self,
position_attributes: list[str],
position_attribute: str,
db_name: str,
db_host: str = "localhost",
db_user: Optional[str] = None,
Expand Down Expand Up @@ -60,8 +62,8 @@ def __init__(
self.cur = self.connection.cursor()

super().__init__(
position_attributes,
mode=mode,
position_attribute=position_attribute,
directed=directed,
total_roi=total_roi,
nodes_table=nodes_table,
Expand All @@ -84,30 +86,28 @@ def _drop_tables(self) -> None:
self._commit()

def _create_tables(self) -> None:
columns = self.position_attributes + list(self.node_attrs.keys())
types = [self.__sql_type(float) + " NOT NULL"] * len(
self.position_attributes
) + list([self.__sql_type(t) for t in self.node_attrs.values()])
columns = self.node_attrs.keys()
types = [self.__sql_type(t) for t in self.node_attrs.values()]
column_types = [f"{c} {t}" for c, t in zip(columns, types)]
self.__exec(
f"CREATE TABLE IF NOT EXISTS "
f"{self.nodes_table_name}("
"id INTEGER not null PRIMARY KEY, "
"id BIGINT not null PRIMARY KEY, "
f"{', '.join(column_types)}"
")"
)
self.__exec(
f"CREATE INDEX IF NOT EXISTS pos_index ON "
f"{self.nodes_table_name}({','.join(self.position_attributes)})"
f"{self.nodes_table_name}({self.position_attribute})"
)

columns = list(self.edge_attrs.keys())
types = list([self.__sql_type(t) for t in self.edge_attrs.values()])
column_types = [f"{c} {t}" for c, t in zip(columns, types)]
self.__exec(
f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}("
f"{self.endpoint_names[0]} INTEGER not null, "
f"{self.endpoint_names[1]} INTEGER not null, "
f"{self.endpoint_names[0]} BIGINT not null, "
f"{self.endpoint_names[1]} BIGINT not null, "
f"{' '.join([c + ',' for c in column_types])}"
f"PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})"
")"
Expand Down Expand Up @@ -177,12 +177,16 @@ def __exec(self, query):
def __sql_value(self, value):
if isinstance(value, str):
return f"'{value}'"
if isinstance(value, Iterable):
return f"array[{','.join([self.__sql_value(v) for v in value])}]"
elif value is None:
return "NULL"
else:
return str(value)

def __sql_type(self, type):
if isinstance(type, Vec):
return self.__sql_type(type.dtype) + f"[{type.size}]"
try:
return {bool: "BOOLEAN", int: "INTEGER", str: "VARCHAR", float: "REAL"}[
type
Expand Down
Loading

0 comments on commit 8d65b0c

Please sign in to comment.