Skip to content

Commit

Permalink
Merge branch 'main' into get_centroid
Browse files Browse the repository at this point in the history
  • Loading branch information
nfahlgren authored Apr 22, 2024
2 parents 95ab0e8 + 3a0d074 commit 48fd848
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 94 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/continuous-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.9', '3.10', '3.11']
os: [ubuntu-latest]
env:
OS: ${{ matrix.os }}
Expand Down
4 changes: 4 additions & 0 deletions plantcv/annotate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from importlib.metadata import version
from plantcv.annotate.classes import Points
from plantcv.annotate.get_centroids import get_centroids

# Auto versioning
__version__ = version("plantcv-annotate")

__all__ = [
"Points",
"get_centroids"
Expand Down
180 changes: 89 additions & 91 deletions plantcv/annotate/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,82 +9,49 @@
from plantcv.plantcv import warn


def _view(self, label="default", color="c", view_all=False):
"""
View the label for a specific class label
Inputs:
label = (optional) class label, by default label="total"
color = desired color, by default color="c"
view_all = indicator of whether view all classes, by default view_all=False
:param label: string
:param color: string
:param view_all: boolean
:return:
"""
if label not in self.coords and color in self.colors.values():
warn("The color assigned to the new class label is already used, if proceeding, "
"items from different classes will not be distinguishable in plots!")
if label is not None:
self.label = label
self.color = color
self.view_all = view_all

if label not in self.coords:
self.coords[self.label] = []
self.count[self.label] = 0
self.colors[self.label] = color

self.fig, self.ax = plt.subplots(1, 1, figsize=self.figsize)

self.events = []
self.fig.canvas.mpl_connect('button_press_event', self.onclick)

self.ax.imshow(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
self.ax.set_title("Please left click on objects\n Right click to remove")
self.p_not_current = 0
# if view_all is True, show all already marked markers
if view_all:
for k in self.coords:
for (x, y) in self.coords[k]:
self.ax.plot(x, y, marker='x', c=self.colors[k])
if self.label not in self.coords or len(self.coords[self.label]) == 0:
self.p_not_current += 1
else:
for (x, y) in self.coords[self.label]:
self.ax.plot(x, y, marker='x', c=color)


class Points:
"""Point annotation/collection class to use in Jupyter notebooks. It allows the user to
interactively click to collect coordinates from an image. Left click collects the point and
right click removes the closest collected point
right click removes the closest collected point.
"""

def __init__(self, img, figsize=(12, 6), label="default"):
"""Initialization
:param img: image data
:param figsize: desired figure size, (12,6) by default
:param label: current label for group of annotations, similar to pcv.params.sample_label
:attribute coords: list of points as (x,y) coordinates tuples
def __init__(self, img, figsize=(12, 6), label="default", color="r", view_all=False):
"""Points initialization method.
Parameters
----------
img : numpy.ndarray
image to annotate
figsize : tuple, optional
figure plotting size, by default (12, 6)
label : str, optional
class label, by default "default"
"""
self.img = img
self.figsize = figsize
self.label = label # current label
self.color = color # current color
self.view_all = view_all # a flag indicating whether or not view all labels
self.coords = {} # dictionary of all coordinates per group label
self.events = [] # includes right and left click events
self.count = {} # a dictionary that saves the counts of different groups (labels)
self.label = label # current label
self.sample_labels = [] # list of all sample labels, one to one with points collected
self.view_all = None # a flag indicating whether or not view all labels
self.color = None # current color
self.colors = {} # all used colors
self.figsize = figsize

_view(self, label=label, color="r", view_all=True)
self.view(label=self.label, color=self.color, view_all=self.view_all)

def onclick(self, event):
"""Handle mouse click events."""
"""Handle mouse click events
Parameters
----------
event : matplotlib.backend_bases.MouseEvent
matplotlib MouseEvent object
"""
print(type(event))
self.events.append(event)
if event.button == 1:

# Add point to the plot
self.ax.plot(event.xdata, event.ydata, marker='x', c=self.color)
self.coords[self.label].append((floor(event.xdata), floor(event.ydata)))
self.count[self.label] += 1
Expand All @@ -99,43 +66,45 @@ def onclick(self, event):
self.sample_labels.pop(idx_remove)
self.fig.canvas.draw()

def print_coords(self, outfile):
def print_coords(self, filename):
"""Save collected coordinates to a file.
Input variables:
outfile = Name of the file to save collected coordinate
:param filename: str
:return:
Parameters
----------
filename : str
output filename
"""
# Open the file for writing
with open(outfile, "w") as fp:
with open(filename, "w") as fp:
# Save the data in JSON format with indentation
json.dump(obj=self.coords, fp=fp, indent=4)

def import_list(self, coords, label="default"):
"""Import center coordinates of already detected objects
Inputs:
coords = list of center coordinates of already detected objects.
label = class label for imported coordinates, by default label="default".
:param coords: list
:param label: string
:return:
"""Import coordinates.
Parameters
----------
coords : list
list of coordinates (tuples)
label : str, optional
class label, by default "default"
"""
if label not in self.coords:
self.coords[label] = []
for (y, x) in coords:
for (x, y) in coords:
self.coords[label].append((x, y))
self.count[label] = len(self.coords[label])
_view(self, label=label, color=self.color, view_all=False)
self.view(label=label, color=self.color, view_all=False)
else:
warn(f"{label} already included and counted, nothing is imported!")

def import_file(self, filename):
"""Method to import coordinates from file to Points object
"""Import coordinates from a file.
Inputs:
filename = filename of stored coordinates and classes
:param filename: str
:return:
Parameters
----------
filename : str
JSON file containing Points annotations
"""
with open(filename, "r") as fp:
coords = json.load(fp)
Expand All @@ -147,16 +116,45 @@ def import_file(self, filename):
keycoor = list(map(lambda sub: (sub[1], sub[0]), keycoor))
self.import_list(keycoor, label=key)

def view(self, label="default", color="c", view_all=False):
"""Method to view current annotations
Inputs:
label = (optional) class label, by default label="total"
color = desired color, by default color="c"
view_all = indicator of whether view all classes, by default view_all=False
:param label: string
:param color: string
:param view_all: boolean
:return:
def view(self, label="default", color="r", view_all=False):
"""View coordinates for a specific class label.
Parameters
----------
label : str, optional
class label, by default "default"
color : str, optional
marker color, by default "r"
view_all : bool, optional
view all classes or a single class, by default False
"""
_view(self, label=label, color=color, view_all=view_all)
if label not in self.coords and color in self.colors.values():
warn("The color assigned to the new class label is already used, if proceeding, "
"items from different classes will not be distinguishable in plots!")
self.label = label
self.color = color
self.view_all = view_all

if self.label not in self.coords:
self.coords[self.label] = []
self.count[self.label] = 0
self.colors[self.label] = self.color

self.fig, self.ax = plt.subplots(1, 1, figsize=self.figsize)

self.events = []
self.fig.canvas.mpl_connect('button_press_event', self.onclick)

self.ax.imshow(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
self.ax.set_title("Please left click on objects\n Right click to remove")
self.p_not_current = 0
# if view_all is True, show all already marked markers
if self.view_all:
for k in self.coords:
for (x, y) in self.coords[k]:
self.ax.plot(x, y, marker='x', c=self.colors[k])
if self.label not in self.coords or len(self.coords[self.label]) == 0:
self.p_not_current += 1
else:
for (x, y) in self.coords[self.label]:
self.ax.plot(x, y, marker='x', c=self.color)
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools >= 61.0"]
requires = ["setuptools >= 64.0", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
Expand All @@ -26,7 +26,15 @@ classifiers = [
"Intended Audience :: Science/Research",
]

[project.optional-dependencies]
test = [
"pytest",
"pytest-cov",
]

[project.urls]
Homepage = "https://plantcv.org"
Documentation = "https://plantcv.readthedocs.io"
Repository = "https://github.com/danforthcenter/plantcv-annotate"

[tool.setuptools_scm]
2 changes: 1 addition & 1 deletion tests/test_annotate_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_points_print_coords(test_data, tmpdir):
drawer_rgb.onclick(e2)

# Save collected coords out
drawer_rgb.print_coords(outfile=filename)
drawer_rgb.print_coords(filename)
assert os.path.exists(filename)

def test_points_import_list(test_data):
Expand Down

0 comments on commit 48fd848

Please sign in to comment.