Skip to content

Commit

Permalink
Merge branch 'main' into napari-join-labels
Browse files Browse the repository at this point in the history
  • Loading branch information
nfahlgren committed May 6, 2024
2 parents ba0c941 + 9fe576f commit d029275
Show file tree
Hide file tree
Showing 13 changed files with 730 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/continuous-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.9', '3.10', '3.11']
os: [ubuntu-latest]
env:
OS: ${{ matrix.os }}
Expand Down
44 changes: 44 additions & 0 deletions docs/get_centroids.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
## Get Centroids

Extract the centroid coordinate (column,row) from regions in a binary image.

**plantcv.annotate.get_centroids**(*bin_img*)

**returns** list containing coordinates of centroids

- **Parameters:**
- bin_img - Binary image containing the connected regions to consider
- **Context:**
- Given an arbitrary mask of the objects of interest, `get_centroids`
returns a list of coordinates that can the be imported into the annotation class [Points](Points.md).

- **Example use:**
- Below

**Binary image**

![count_img](img/documentation_images/get_centroids/discs_mask.png)

```python

from plantcv import plantcv as pcv

# Set global debug behavior to None (default), "print" (to file),
# or "plot"
pcv.params.debug = "plot"

# Apply get centroids to the binary image
coords = pcv.annotate.get_centroids(bin_img=binary_img)
print(coords)
# [[1902, 600], [1839, 1363], [1837, 383], [1669, 1977], [1631, 1889], [1590, 1372], [1550, 1525],
# [1538, 1633], [1522, 1131], [1494, 2396], [1482, 1917], [1446, 1808], [1425, 726], [1418, 2392],
# [1389, 198], [1358, 1712], [1288, 522], [1289, 406], [1279, 368], [1262, 1376], [1244, 1795],
# [1224, 1327], [1201, 624], [1181, 725], [1062, 85], [999, 840], [885, 399], [740, 324], [728, 224],
# [697, 860], [660, 650], [638, 2390], [622, 1565], [577, 497], [572, 2179], [550, 2230], [547, 1826],
# [537, 892], [538, 481], [524, 2144], [521, 2336], [497, 201], [385, 1141], [342, 683], [342, 102],
# [332, 1700], [295, 646], [271, 60], [269, 1626], [210, 1694], [189, 878], [178, 1570], [171, 2307],
# [61, 286], [28, 2342]]

```

**Source Code:** [Here](https://github.com/danforthcenter/plantcv-annotate/blob/main/plantcv/annotate/get_centroids.py)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ nav:
- Napari Label: napari_label_classes.md
- Napari Open: napari_open.md
- Points: Points.md

- Get Centroids: get_centroids.md
markdown_extensions:
- toc:
permalink: True
Expand Down
9 changes: 8 additions & 1 deletion plantcv/annotate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from importlib.metadata import version
from plantcv.annotate.classes import Points
from plantcv.annotate.get_centroids import get_centroids
from plantcv.annotate.napari_classes import napari_classes
from plantcv.annotate.napari_open import napari_open
from plantcv.annotate.napari_label_classes import napari_label_classes
from plantcv.annotate.napari_join_labels import napari_join_labels

# Auto versioning
__version__ = version("plantcv-annotate")

__all__ = [
"Points",
"get_centroids",
"napari_classes",
"napari_open",
"napari_label_classes",
"napari_join_labels"
]
]
160 changes: 160 additions & 0 deletions plantcv/annotate/classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Class helpers

# Imports
import cv2
import json
from math import floor
import matplotlib.pyplot as plt
from plantcv.plantcv.annotate.points import _find_closest_pt
from plantcv.plantcv import warn


class Points:
"""Point annotation/collection class to use in Jupyter notebooks. It allows the user to
interactively click to collect coordinates from an image. Left click collects the point and
right click removes the closest collected point.
"""

def __init__(self, img, figsize=(12, 6), label="default", color="r", view_all=False):
"""Points initialization method.
Parameters
----------
img : numpy.ndarray
image to annotate
figsize : tuple, optional
figure plotting size, by default (12, 6)
label : str, optional
class label, by default "default"
"""
self.img = img
self.figsize = figsize
self.label = label # current label
self.color = color # current color
self.view_all = view_all # a flag indicating whether or not view all labels
self.coords = {} # dictionary of all coordinates per group label
self.events = [] # includes right and left click events
self.count = {} # a dictionary that saves the counts of different groups (labels)
self.sample_labels = [] # list of all sample labels, one to one with points collected
self.colors = {} # all used colors

self.view(label=self.label, color=self.color, view_all=self.view_all)

def onclick(self, event):
"""Handle mouse click events
Parameters
----------
event : matplotlib.backend_bases.MouseEvent
matplotlib MouseEvent object
"""
print(type(event))
self.events.append(event)
if event.button == 1:
# Add point to the plot
self.ax.plot(event.xdata, event.ydata, marker='x', c=self.color)
self.coords[self.label].append((floor(event.xdata), floor(event.ydata)))
self.count[self.label] += 1
self.sample_labels.append(self.label)
else:
idx_remove, _ = _find_closest_pt((event.xdata, event.ydata), self.coords[self.label])
# remove the closest point to the user right clicked one
self.coords[self.label].pop(idx_remove)
self.count[self.label] -= 1
idx_remove = idx_remove + self.p_not_current
self.ax.lines[idx_remove].remove()
self.sample_labels.pop(idx_remove)
self.fig.canvas.draw()

def print_coords(self, filename):
"""Save collected coordinates to a file.
Parameters
----------
filename : str
output filename
"""
# Open the file for writing
with open(filename, "w") as fp:
# Save the data in JSON format with indentation
json.dump(obj=self.coords, fp=fp, indent=4)

def import_list(self, coords, label="default"):
"""Import coordinates.
Parameters
----------
coords : list
list of coordinates (tuples)
label : str, optional
class label, by default "default"
"""
if label not in self.coords:
self.coords[label] = []
for (x, y) in coords:
self.coords[label].append((x, y))
self.count[label] = len(self.coords[label])
self.view(label=label, color=self.color, view_all=False)
else:
warn(f"{label} already included and counted, nothing is imported!")

def import_file(self, filename):
"""Import coordinates from a file.
Parameters
----------
filename : str
JSON file containing Points annotations
"""
with open(filename, "r") as fp:
coords = json.load(fp)

keys = list(coords.keys())

for key in keys:
keycoor = coords[key]
keycoor = list(map(lambda sub: (sub[1], sub[0]), keycoor))
self.import_list(keycoor, label=key)

def view(self, label="default", color="r", view_all=False):
"""View coordinates for a specific class label.
Parameters
----------
label : str, optional
class label, by default "default"
color : str, optional
marker color, by default "r"
view_all : bool, optional
view all classes or a single class, by default False
"""
if label not in self.coords and color in self.colors.values():
warn("The color assigned to the new class label is already used, if proceeding, "
"items from different classes will not be distinguishable in plots!")
self.label = label
self.color = color
self.view_all = view_all

if self.label not in self.coords:
self.coords[self.label] = []
self.count[self.label] = 0
self.colors[self.label] = self.color

self.fig, self.ax = plt.subplots(1, 1, figsize=self.figsize)

self.events = []
self.fig.canvas.mpl_connect('button_press_event', self.onclick)

self.ax.imshow(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
self.ax.set_title("Please left click on objects\n Right click to remove")
self.p_not_current = 0
# if view_all is True, show all already marked markers
if self.view_all:
for k in self.coords:
for (x, y) in self.coords[k]:
self.ax.plot(x, y, marker='x', c=self.colors[k])
if self.label not in self.coords or len(self.coords[self.label]) == 0:
self.p_not_current += 1
else:
for (x, y) in self.coords[self.label]:
self.ax.plot(x, y, marker='x', c=self.color)
28 changes: 28 additions & 0 deletions plantcv/annotate/get_centroids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Get centroids from mask objects

from skimage.measure import label, regionprops


def get_centroids(bin_img):
"""Get the coordinates (row,column) of the centroid of each connected region in a binary image.
Inputs:
bin_img = Binary image containing the connected regions to consider
Returns:
coords = List of coordinates (row,column) of the centroids of the regions
:param bin_img: numpy.ndarray
:return coords: list
"""
# find contours in the binary image
labeled_img = label(bin_img)
# measure regions
obj_measures = regionprops(labeled_img)
coords = []
for obj in obj_measures:
# Convert coord values to int
coord = tuple(map(int, obj.centroid))
coords.append(coord)

return coords
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools >= 61.0"]
requires = ["setuptools >= 64.0", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
Expand Down Expand Up @@ -27,12 +27,17 @@ classifiers = [
"License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)",
"Intended Audience :: Science/Research",
]

[project.optional-dependencies]
test = [
"pytest",
"pytest-cov",
"pytest-qt"
]

[project.urls]
Homepage = "https://plantcv.org"
Documentation = "https://plantcv.readthedocs.io"
Repository = "https://github.com/danforthcenter/plantcv-annotate"
Repository = "https://github.com/danforthcenter/plantcv-annotate"

[tool.setuptools_scm]
20 changes: 11 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import pytest
import matplotlib
import os


# Disable plotting
Expand All @@ -11,17 +11,19 @@ class TestData:
def __init__(self):
"""Initialize simple variables."""
# Test data directory
self.datadir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"testdata")
self.datadir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
# Flat image directory
self.snapshot_dir = os.path.join(self.datadir, "snapshot_dir")
# RGB image
filename_rgb = "setaria_small_plant_rgb.png"
self.small_rgb_img = os.path.join(self.datadir, filename_rgb)
self.small_rgb_img = os.path.join(self.datadir, "setaria_small_plant_rgb.png")
# Binary image
self.small_bin_img = os.path.join(self.datadir, "setaria_small_plant_mask.png")
# Text file with tuple coordinates (by group label)
self.pollen_coords = os.path.join(self.datadir, "points_file_import.coords")
# Kmeans Clustered Gray image
filename_kmeans = "silphium_seed_labeled_example.png"
self.kmeans_seed_gray_img = os.path.join(self.datadir, filename_kmeans)
self.kmeans_seed_gray_img = os.path.join(self.datadir, "silphium_seed_labeled_example.png")
# Small Hyperspectral image
filename_hyper = "corn-kernel-hyperspectral.raw"
self.envi_sample_data = os.path.join(self.datadir, filename_hyper)
self.envi_sample_data = os.path.join(self.datadir, "corn-kernel-hyperspectral.raw")


@pytest.fixture(scope="session")
Expand Down
12 changes: 12 additions & 0 deletions tests/test_annotate_get_centroids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import cv2
from plantcv.annotate.get_centroids import get_centroids


def test_get_centroids(test_data):
"""Test for PlantCV."""
# Read in test data
mask = cv2.imread(test_data.small_bin_img, -1)

coor = get_centroids(bin_img=mask)

assert coor == [(166, 214)]
Loading

0 comments on commit d029275

Please sign in to comment.