Skip to content

Commit

Permalink
added test to verify .db extension
Browse files Browse the repository at this point in the history
  • Loading branch information
AbishekRajVG committed Oct 20, 2023
1 parent 4b3ed8d commit ee2b156
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
34 changes: 34 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,3 +1780,37 @@ def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None:
# check correct error is raised if coordinates are missing
with pytest.raises(ValueError, match="coordinates"):
misc.patch_pred_store(patch_output, (1.0, 1.0))


def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
"""Test patch_pred_store and ensures the output file extension is `.db`."""
# Define a mock patch_output
patch_output = {
"predictions": [1, 0, 1],
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
"probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]],
"labels": [1, 0, 1],
}

# sends the path of a jpeg source image, expects .db file in the same directory
save_path = tmp_path / "patch_output" / "output.jpeg"

store_path = misc.patch_pred_store(patch_output, (1.0, 1.0), save_path=save_path)

print("Annotation store path: ", store_path)
assert Path.exists(store_path), "Annotation Store output file does not exist"

store = SQLiteStore(store_path)

# Check that its an SQLiteStore containing the expected annotations
assert isinstance(store, SQLiteStore)
assert len(store) == 3
for annotation in store.values():
assert annotation.geometry.area == 1
assert annotation.properties["type"] in [0, 1]
assert "other" not in annotation.properties

patch_output.pop("coordinates")
# check correct error is raised if coordinates are missing
with pytest.raises(ValueError, match="coordinates"):
misc.patch_pred_store(patch_output, (1.0, 1.0))
6 changes: 2 additions & 4 deletions tiatoolbox/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,8 +1268,7 @@ def patch_pred_store(
# ensure parent directory exisits
save_path.parent.absolute().mkdir(parents=True, exist_ok=True)
# ensure proper db extension
if save_path.suffix != ".db":
save_path = save_path.parent.absolute() / (save_path.stem + ".db")
save_path = save_path.parent.absolute() / (save_path.stem + ".db")
store.dump(save_path)
return save_path

Expand Down Expand Up @@ -1300,8 +1299,7 @@ def patch_pred_store_zarr(
chunks = kwargs["chunks"] if "chunks" in kwargs else 10000

# ensure proper zarr extension
if save_path.suffix != ".zarr":
save_path = save_path.parent.absolute() / (save_path.stem + ".zarr")
save_path = save_path.parent.absolute() / (save_path.stem + ".zarr")

# save to zarr
predictions_array = np.array(raw_predictions["predictions"])
Expand Down

0 comments on commit ee2b156

Please sign in to comment.