diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/funlib/persistence/graphs/sqlite_graph_database.py index 1a42974..d70bc45 100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/funlib/persistence/graphs/sqlite_graph_database.py @@ -328,7 +328,7 @@ def read_nodes( for key, val in zip( ["id"] + self.position_attributes + self.node_attrs, values ) - if key in read_attrs + if key in read_attrs and val is not None } for values in self.cur.execute(select_statement) ] @@ -337,7 +337,7 @@ def read_nodes( if isinstance(self.position_attribute, str): for data in nodes: - data[self.position_attribute] = self.__get_node_pos(data) + data[self.position_attribute] = self.__combine_pos(data) return nodes @@ -605,21 +605,22 @@ def write_nodes( raise NotImplementedError( "Fail if exists not implemented for " "file backend" ) - if attributes is not None: - raise NotImplementedError("Attributes not implemented for file backend") if self.mode == "r": raise NotImplementedError("Trying to write to read-only DB") logger.debug("Writing nodes in %s", roi) + attrs = attributes if attributes is not None else self.node_attrs + insert_statement = ( f"INSERT{' OR IGNORE' if not fail_if_exists else ''} INTO {self.nodes_collection_name} " - f"(id, {', '.join(self.position_attributes + self.node_attrs)}) VALUES " - f"({', '.join(['?'] * (len(self.position_attributes) + len(self.node_attrs) + 1))})" + f"(id, {', '.join(self.position_attributes + attrs)}) VALUES " + f"({', '.join(['?'] * (len(self.position_attributes) + len(attrs) + 1))})" ) to_insert = [] for node_id, data in nodes.items(): + data = data.copy() pos = self.__get_node_pos(data) if roi is not None and not roi.contains(pos): continue @@ -627,10 +628,7 @@ def write_nodes( data[position_attribute] = pos[i] to_insert.append( [node_id] - + [ - data.get(attr, None) - for attr in self.position_attributes + self.node_attrs - ] + + [data.get(attr, None) for attr in self.position_attributes + attrs] ) if len(to_insert) == 0: @@ -690,6 +688,11 @@ def __remove_keys(self, dictionary, keys): return {k: v for k, v in dictionary.items() if k not in keys} + def __combine_pos(self, n: dict[str, Any]) -> Coordinate: + return Coordinate( + (n.pop(pos_attr, None) for pos_attr in self.position_attributes) + ) + def __get_node_pos(self, n: dict[str, Any]) -> Coordinate: if isinstance(self.position_attribute, str): return Coordinate(n.get(self.position_attribute, (None,) * self.ndim)) diff --git a/tests/test_graph.py b/tests/test_graph.py index 1e86e5d..ea0857a 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -296,19 +296,21 @@ def test_graph_fail_if_not_exists(provider_factory): def test_graph_write_attributes(provider_factory): - graph_provider = provider_factory("w") + graph_provider = provider_factory("w", node_attrs=["swip"]) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] graph.add_node(2, comment="without position") graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") - graph.add_node(57, position=Coordinate((7, 7, 7)), zap="zip") + graph.add_node(57, position=(7, 7, 7), zap="zip") graph.add_edge(42, 23) graph.add_edge(57, 23) graph.add_edge(2, 42) try: - graph_provider.write_nodes(graph.nodes(), attributes=["position", "swip"]) + graph_provider.write_graph( + graph, write_nodes=True, write_edges=False, node_attrs=["swip"] + ) except NotImplementedError: pytest.xfail() graph_provider.write_edges( @@ -325,7 +327,6 @@ def test_graph_write_attributes(provider_factory): continue if "zap" in data: del data["zap"] - data["position"] = list(data["position"]) nodes.append((node, data)) compare_nodes = compare_graph.nodes(data=True)