Skip to content

Commit

Permalink
pass final test
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Nov 16, 2023
1 parent 5f39387 commit 7a21df5
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
23 changes: 13 additions & 10 deletions funlib/persistence/graphs/sqlite_graph_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def read_nodes(
for key, val in zip(
["id"] + self.position_attributes + self.node_attrs, values
)
if key in read_attrs
if key in read_attrs and val is not None
}
for values in self.cur.execute(select_statement)
]
Expand All @@ -337,7 +337,7 @@ def read_nodes(

if isinstance(self.position_attribute, str):
for data in nodes:
data[self.position_attribute] = self.__get_node_pos(data)
data[self.position_attribute] = self.__combine_pos(data)

return nodes

Expand Down Expand Up @@ -605,32 +605,30 @@ def write_nodes(
raise NotImplementedError(
"Fail if exists not implemented for " "file backend"
)
if attributes is not None:
raise NotImplementedError("Attributes not implemented for file backend")
if self.mode == "r":
raise NotImplementedError("Trying to write to read-only DB")

logger.debug("Writing nodes in %s", roi)

attrs = attributes if attributes is not None else self.node_attrs

insert_statement = (
f"INSERT{' OR IGNORE' if not fail_if_exists else ''} INTO {self.nodes_collection_name} "
f"(id, {', '.join(self.position_attributes + self.node_attrs)}) VALUES "
f"({', '.join(['?'] * (len(self.position_attributes) + len(self.node_attrs) + 1))})"
f"(id, {', '.join(self.position_attributes + attrs)}) VALUES "
f"({', '.join(['?'] * (len(self.position_attributes) + len(attrs) + 1))})"
)

to_insert = []
for node_id, data in nodes.items():
data = data.copy()
pos = self.__get_node_pos(data)
if roi is not None and not roi.contains(pos):
continue
for i, position_attribute in enumerate(self.position_attributes):
data[position_attribute] = pos[i]
to_insert.append(
[node_id]
+ [
data.get(attr, None)
for attr in self.position_attributes + self.node_attrs
]
+ [data.get(attr, None) for attr in self.position_attributes + attrs]
)

if len(to_insert) == 0:
Expand Down Expand Up @@ -690,6 +688,11 @@ def __remove_keys(self, dictionary, keys):

return {k: v for k, v in dictionary.items() if k not in keys}

def __combine_pos(self, n: dict[str, Any]) -> Coordinate:
return Coordinate(
(n.pop(pos_attr, None) for pos_attr in self.position_attributes)
)

def __get_node_pos(self, n: dict[str, Any]) -> Coordinate:
if isinstance(self.position_attribute, str):
return Coordinate(n.get(self.position_attribute, (None,) * self.ndim))
Expand Down
9 changes: 5 additions & 4 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,21 @@ def test_graph_fail_if_not_exists(provider_factory):


def test_graph_write_attributes(provider_factory):
graph_provider = provider_factory("w")
graph_provider = provider_factory("w", node_attrs=["swip"])
graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))]

graph.add_node(2, comment="without position")
graph.add_node(42, position=(1, 1, 1))
graph.add_node(23, position=(5, 5, 5), swip="swap")
graph.add_node(57, position=Coordinate((7, 7, 7)), zap="zip")
graph.add_node(57, position=(7, 7, 7), zap="zip")
graph.add_edge(42, 23)
graph.add_edge(57, 23)
graph.add_edge(2, 42)

try:
graph_provider.write_nodes(graph.nodes(), attributes=["position", "swip"])
graph_provider.write_graph(
graph, write_nodes=True, write_edges=False, node_attrs=["swip"]
)
except NotImplementedError:
pytest.xfail()
graph_provider.write_edges(
Expand All @@ -325,7 +327,6 @@ def test_graph_write_attributes(provider_factory):
continue
if "zap" in data:
del data["zap"]
data["position"] = list(data["position"])
nodes.append((node, data))

compare_nodes = compare_graph.nodes(data=True)
Expand Down

0 comments on commit 7a21df5

Please sign in to comment.