Skip to content

Commit

Permalink
Move metadata handling into SQLGraphDatabase
Browse files Browse the repository at this point in the history
  • Loading branch information
funkey committed Nov 16, 2023
1 parent 817e9a2 commit 9f867c7
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 84 deletions.
84 changes: 82 additions & 2 deletions funlib/persistence/graphs/sql_graph_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
self._drop_tables()

self._create_tables()
self._init_metadata()
self.__init_metadata()

@abstractmethod
def _drop_tables(self) -> None:
Expand All @@ -97,7 +97,11 @@ def _create_tables(self) -> None:
pass

@abstractmethod
def _init_metadata(self) -> None:
def _store_metadata(self, metadata) -> None:
pass

@abstractmethod
def _read_metadata(self) -> dict[str, Any]:
pass

@abstractmethod
Expand Down Expand Up @@ -490,6 +494,82 @@ def update_nodes(

self._commit()

def __init_metadata(self):
metadata = self._read_metadata()

if metadata:
self.__check_metadata(metadata)
else:
metadata = self.__create_metadata()
self._store_metadata(metadata)

def __create_metadata(self):
"""Sets the metadata in the meta collection to the provided values"""

if not self.directed:
# default is False
self.directed = self.directed if self.directed is not None else False
if not self.total_roi:
# default is an unbounded roi
self.total_roi = Roi(
(None,) * len(self.position_attributes),
(None,) * len(self.position_attributes),
)

metadata = {
"directed": self.directed,
"total_roi_offset": self.total_roi.offset,
"total_roi_shape": self.total_roi.shape,
"node_attrs": self.node_attrs,
"edge_attrs": self.edge_attrs,
}

return metadata

def __check_metadata(self, metadata):
"""Checks if the provided metadata matches the existing
metadata in the meta collection"""

if self.directed is not None and metadata["directed"] != self.directed:
raise ValueError(
(
"Input parameter directed={} does not match"
"directed value {} already in stored metadata"
).format(self.directed, metadata["directed"])
)
elif self.directed is None:
self.directed = metadata["directed"]
if self.total_roi is not None:
if self.total_roi.get_offset() != metadata["total_roi_offset"]:
raise ValueError(
(
"Input total_roi offset {} does not match"
"total_roi offset {} already stored in metadata"
).format(self.total_roi.get_offset(), metadata["total_roi_offset"])
)
if self.total_roi.get_shape() != metadata["total_roi_shape"]:
raise ValueError(
(
"Input total_roi shape {} does not match"
"total_roi shape {} already stored in metadata"
).format(self.total_roi.get_shape(), metadata["total_roi_shape"])
)
else:
self.total_roi = Roi(
metadata["total_roi_offset"], metadata["total_roi_shape"]
)
if self._node_attrs is not None:
assert self.node_attrs == metadata["node_attrs"], (
self.node_attrs,
metadata["node_attrs"],
)
else:
self.node_attrs = metadata["node_attrs"]
if self._edge_attrs is not None:
assert self.edge_attrs == metadata["edge_attrs"]
else:
self.edge_attrs = metadata["edge_attrs"]

def __remove_keys(self, dictionary, keys):
"""Removes given keys from dictionary."""

Expand Down
93 changes: 11 additions & 82 deletions funlib/persistence/graphs/sqlite_graph_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,81 +42,6 @@ def __init__(
edge_attrs=edge_attrs,
)

def _init_metadata(self):
if self.meta_collection.exists():
self.__check_metadata()
else:
self.__set_metadata()

def __check_metadata(self):
"""Checks if the provided metadata matches the existing
metadata in the meta collection"""

metadata = self.__get_metadata()
if self.directed is not None and metadata["directed"] != self.directed:
raise ValueError(
(
"Input parameter directed={} does not match"
"directed value {} already in stored metadata"
).format(self.directed, metadata["directed"])
)
elif self.directed is None:
self.directed = metadata["directed"]
if self.total_roi is not None:
if self.total_roi.get_offset() != metadata["total_roi_offset"]:
raise ValueError(
(
"Input total_roi offset {} does not match"
"total_roi offset {} already stored in metadata"
).format(self.total_roi.get_offset(), metadata["total_roi_offset"])
)
if self.total_roi.get_shape() != metadata["total_roi_shape"]:
raise ValueError(
(
"Input total_roi shape {} does not match"
"total_roi shape {} already stored in metadata"
).format(self.total_roi.get_shape(), metadata["total_roi_shape"])
)
else:
self.total_roi = Roi(
metadata["total_roi_offset"], metadata["total_roi_shape"]
)
if self._node_attrs is not None:
assert self.node_attrs == metadata["node_attrs"], (
self.node_attrs,
metadata["node_attrs"],
)
else:
self.node_attrs = metadata["node_attrs"]
if self._edge_attrs is not None:
assert self.edge_attrs == metadata["edge_attrs"]
else:
self.edge_attrs = metadata["edge_attrs"]

def __set_metadata(self):
"""Sets the metadata in the meta collection to the provided values"""

if not self.directed:
# default is False
self.directed = self.directed if self.directed is not None else False
if not self.total_roi:
# default is an unbounded roi
self.total_roi = Roi(
(None,) * len(self.position_attributes),
(None,) * len(self.position_attributes),
)

meta_data = {
"directed": self.directed,
"total_roi_offset": self.total_roi.offset,
"total_roi_shape": self.total_roi.shape,
"node_attrs": self.node_attrs,
"edge_attrs": self.edge_attrs,
}

with open(self.meta_collection, "w") as f:
json.dump(meta_data, f)

def _drop_tables(self) -> None:
logger.info(
"dropping collections %s, %s",
Expand Down Expand Up @@ -158,6 +83,17 @@ def _create_tables(self) -> None:
+ ")"
)

def _store_metadata(self, metadata):
with open(self.meta_collection, "w") as f:
json.dump(metadata, f)

def _read_metadata(self) -> dict[str, Any]:
if not self.meta_collection.exists():
return None

with open(self.meta_collection, "r") as f:
return json.load(f)

def _select_query(self, query):
try:
return self.cur.execute(query)
Expand Down Expand Up @@ -185,10 +121,3 @@ def _update_query(self, query, commit=True):

def _commit(self):
self.con.commit()

def __get_metadata(self) -> dict[str, Any]:
"""Gets metadata out of the meta collection and returns it
as a dictionary."""

with open(self.meta_collection, "r") as f:
return json.load(f)

0 comments on commit 9f867c7

Please sign in to comment.