Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add basic functionality for annotating Points #3

Merged
merged 39 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
80b1cd3
add _find_closest and onclick
HaleySchuhl Feb 29, 2024
7d0e0e8
test suite setup
HaleySchuhl Mar 1, 2024
9325e4d
add test data
HaleySchuhl Mar 1, 2024
bfbeee3
import _find_closest_pt from pcv
HaleySchuhl Mar 1, 2024
9f3ef42
update test name and import statement
HaleySchuhl Mar 1, 2024
144312e
trailing whitespace and unused import
HaleySchuhl Mar 4, 2024
c53dfa7
add line return at end of file
HaleySchuhl Mar 4, 2024
3b02f56
add line return at end of file
HaleySchuhl Mar 4, 2024
a985065
refactor points attribute to coords
HaleySchuhl Mar 4, 2024
c2fb999
attribute name in tests
HaleySchuhl Mar 4, 2024
915ff0d
move test file into correct subdir
HaleySchuhl Mar 4, 2024
530f65e
Update conftest.py
HaleySchuhl Mar 4, 2024
2eb4289
move testdata into tests/
HaleySchuhl Mar 4, 2024
fe2f12f
delete extra test structure
HaleySchuhl Mar 4, 2024
6418f44
Merge branch 'main' into add_points
HaleySchuhl Mar 4, 2024
8c08ace
populate init
HaleySchuhl Mar 4, 2024
0c95c4b
add viewing method, update onclick and add print_coords
HaleySchuhl Mar 4, 2024
3ead6b6
add test for save_coords
HaleySchuhl Mar 4, 2024
c16ffe4
delete whitespace, other deepsource, and add import_list method
HaleySchuhl Mar 4, 2024
2037760
add test for importing lists
HaleySchuhl Mar 4, 2024
230b6df
add import_file method
HaleySchuhl Mar 4, 2024
1ba7f4d
add test data file
HaleySchuhl Mar 4, 2024
3c7c609
add test for importing from file
HaleySchuhl Mar 4, 2024
52169cf
Merge branch 'main' into add_points
nfahlgren Mar 4, 2024
da2443c
add test for warning triggered in import_file
HaleySchuhl Mar 5, 2024
d2606c9
actually add method .view
HaleySchuhl Mar 5, 2024
9b2689f
add tests for untested code
HaleySchuhl Mar 5, 2024
8c26942
Merge branch 'add_points' of https://github.com/danforthcenter/plantc…
HaleySchuhl Mar 5, 2024
fa7aeb8
plot updated annotations after importing file and list
HaleySchuhl Mar 5, 2024
664c3dd
add tests for full coverage
HaleySchuhl Mar 5, 2024
cd48f84
deepsource issues
HaleySchuhl Mar 5, 2024
6540e27
deepsource issue about reading and writing files
HaleySchuhl Mar 5, 2024
2cabecd
whitespace
HaleySchuhl Mar 5, 2024
b276a17
change input variable name back to "filename"
HaleySchuhl Mar 6, 2024
240e9ca
Merge branch 'main' into add_points
HaleySchuhl Mar 12, 2024
bd56de0
Merge branch 'main' into add_points
nfahlgren Apr 17, 2024
fad8ed4
Reformat docstrings to numpy format
nfahlgren Apr 18, 2024
f500605
Simplify code and remove private funciton
nfahlgren Apr 18, 2024
8ff04cc
Fix x and y assignments
nfahlgren Apr 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions plantcv/annotate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from importlib.metadata import version
from plantcv.annotate.classes import Points

# Auto versioning
__version__ = version("plantcv-annotate")

__all__ = [
"Points"
]
160 changes: 160 additions & 0 deletions plantcv/annotate/classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Class helpers

# Imports
import cv2
import json
from math import floor
import matplotlib.pyplot as plt
from plantcv.plantcv.annotate.points import _find_closest_pt
from plantcv.plantcv import warn


class Points:
"""Point annotation/collection class to use in Jupyter notebooks. It allows the user to
interactively click to collect coordinates from an image. Left click collects the point and
right click removes the closest collected point.
"""

def __init__(self, img, figsize=(12, 6), label="default", color="r", view_all=False):
"""Points initialization method.

Parameters
----------
img : numpy.ndarray
image to annotate
figsize : tuple, optional
figure plotting size, by default (12, 6)
label : str, optional
class label, by default "default"
"""
self.img = img
self.figsize = figsize
self.label = label # current label
self.color = color # current color
self.view_all = view_all # a flag indicating whether or not view all labels
self.coords = {} # dictionary of all coordinates per group label
self.events = [] # includes right and left click events
self.count = {} # a dictionary that saves the counts of different groups (labels)
self.sample_labels = [] # list of all sample labels, one to one with points collected
self.colors = {} # all used colors

self.view(label=self.label, color=self.color, view_all=self.view_all)

def onclick(self, event):
"""Handle mouse click events

Parameters
----------
event : matplotlib.backend_bases.MouseEvent
matplotlib MouseEvent object
"""
print(type(event))
self.events.append(event)
if event.button == 1:
# Add point to the plot
self.ax.plot(event.xdata, event.ydata, marker='x', c=self.color)
self.coords[self.label].append((floor(event.xdata), floor(event.ydata)))
self.count[self.label] += 1
self.sample_labels.append(self.label)
else:
idx_remove, _ = _find_closest_pt((event.xdata, event.ydata), self.coords[self.label])
# remove the closest point to the user right clicked one
self.coords[self.label].pop(idx_remove)
self.count[self.label] -= 1
idx_remove = idx_remove + self.p_not_current
self.ax.lines[idx_remove].remove()
self.sample_labels.pop(idx_remove)
self.fig.canvas.draw()

def print_coords(self, filename):
"""Save collected coordinates to a file.

Parameters
----------
filename : str
output filename
"""
# Open the file for writing
with open(filename, "w") as fp:
# Save the data in JSON format with indentation
json.dump(obj=self.coords, fp=fp, indent=4)

def import_list(self, coords, label="default"):
"""Import coordinates.

Parameters
----------
coords : list
list of coordinates (tuples)
label : str, optional
class label, by default "default"
"""
if label not in self.coords:
self.coords[label] = []
for (x, y) in coords:
self.coords[label].append((x, y))
self.count[label] = len(self.coords[label])
self.view(label=label, color=self.color, view_all=False)
else:
warn(f"{label} already included and counted, nothing is imported!")

def import_file(self, filename):
"""Import coordinates from a file.

Parameters
----------
filename : str
JSON file containing Points annotations
"""
with open(filename, "r") as fp:
coords = json.load(fp)

keys = list(coords.keys())

for key in keys:
keycoor = coords[key]
keycoor = list(map(lambda sub: (sub[1], sub[0]), keycoor))
self.import_list(keycoor, label=key)

def view(self, label="default", color="r", view_all=False):
"""View coordinates for a specific class label.

Parameters
----------
label : str, optional
class label, by default "default"
color : str, optional
marker color, by default "r"
view_all : bool, optional
view all classes or a single class, by default False
"""
if label not in self.coords and color in self.colors.values():
warn("The color assigned to the new class label is already used, if proceeding, "
"items from different classes will not be distinguishable in plots!")
self.label = label
self.color = color
self.view_all = view_all

if self.label not in self.coords:
self.coords[self.label] = []
self.count[self.label] = 0
self.colors[self.label] = self.color

self.fig, self.ax = plt.subplots(1, 1, figsize=self.figsize)

self.events = []
self.fig.canvas.mpl_connect('button_press_event', self.onclick)

self.ax.imshow(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
self.ax.set_title("Please left click on objects\n Right click to remove")
self.p_not_current = 0
# if view_all is True, show all already marked markers
if self.view_all:
for k in self.coords:
for (x, y) in self.coords[k]:
self.ax.plot(x, y, marker='x', c=self.colors[k])
if self.label not in self.coords or len(self.coords[self.label]) == 0:
self.p_not_current += 1
else:
for (x, y) in self.coords[self.label]:
self.ax.plot(x, y, marker='x', c=self.color)
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,24 @@
import os
import pytest
import matplotlib

# Disable plotting
matplotlib.use("Template")


class TestData:
def __init__(self):
"""Initialize simple variables."""
# Test data directory
self.datadir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
# Flat image directory
self.snapshot_dir = os.path.join(self.datadir, "snapshot_dir")
# RGB image
self.small_rgb_img = os.path.join(self.datadir, "setaria_small_plant_rgb.png")
# Text file with tuple coordinates (by group label)
self.pollen_coords = os.path.join(self.datadir, "points_file_import.coords")

@pytest.fixture(scope="session")
def test_data():
"""Test data object for the main PlantCV package."""
return TestData()
152 changes: 152 additions & 0 deletions tests/test_annotate_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Tests for annotate.Points."""
import os
import cv2
import matplotlib
from plantcv.annotate.classes import Points


def test_points(test_data):
"""Test for plantcv-annotate."""
# Read in a test grayscale image
img = cv2.imread(test_data.small_rgb_img)

# initialize interactive tool
drawer_rgb = Points(img, figsize=(12, 6))

# simulate mouse clicks
# event 1, left click to add point
e1 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=drawer_rgb.fig.canvas,
x=0, y=0, button=1)
point1 = (200, 200)
e1.xdata, e1.ydata = point1
drawer_rgb.onclick(e1)

# event 2, left click to add point
e2 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=drawer_rgb.fig.canvas,
x=0, y=0, button=1)
e2.xdata, e2.ydata = (300, 200)
drawer_rgb.onclick(e2)

# event 3, left click to add point
e3 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=drawer_rgb.fig.canvas,
x=0, y=0, button=1)
e3.xdata, e3.ydata = (50, 50)
drawer_rgb.onclick(e3)

# event 4, right click to remove point with exact coordinates
e4 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=drawer_rgb.fig.canvas,
x=0, y=0, button=3)
e4.xdata, e4.ydata = (50, 50)
drawer_rgb.onclick(e4)

# event 5, right click to remove point with coordinates close but not equal
e5 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=drawer_rgb.fig.canvas,
x=0, y=0, button=3)
e5.xdata, e5.ydata = (301, 200)
drawer_rgb.onclick(e5)

assert drawer_rgb.coords["default"][0] == point1

def test_points_print_coords(test_data, tmpdir):
"""Test for plantcv-annotate."""
cache_dir = tmpdir.mkdir("cache")
filename = os.path.join(cache_dir, 'plantcv_print_coords.txt')
# Read in a test image
img = cv2.imread(test_data.small_rgb_img)

# initialize interactive tool
drawer_rgb = Points(img, figsize=(12, 6))

# simulate mouse clicks
# event 1, left click to add point
e1 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=drawer_rgb.fig.canvas,
x=0, y=0, button=1)
point1 = (200, 200)
e1.xdata, e1.ydata = point1
drawer_rgb.onclick(e1)

# event 2, left click to add point
e2 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=drawer_rgb.fig.canvas,
x=0, y=0, button=1)
e2.xdata, e2.ydata = (300, 200)
drawer_rgb.onclick(e2)

# Save collected coords out
drawer_rgb.print_coords(filename)
assert os.path.exists(filename)

def test_points_import_list(test_data):
"""Test for plantcv-annotate."""
# Read in a test image
img = cv2.imread(test_data.small_rgb_img)
# initialize interactive tool
drawer_rgb = Points(img, figsize=(12, 6), label="default")
totalpoints1 = [(158, 531), (361, 112), (500, 418), (269.25303806488864, 385.69839981447126),
(231.21964288863632, 445.995245825603), (293.37177646934134, 448.778177179963), (240.49608073650273, 277.1640769944342),
(279.4571196975417, 240.05832560296852), (77.23077461405376, 165.84682282003712), (420, 364),
(509.5127783246289, 353.2308673469388), (527.1380102355752, 275.3087894248609), (445.50535717435065, 138.94515306122452)]
drawer_rgb.import_list(coords=totalpoints1, label="imported")

assert len(drawer_rgb.coords["imported"]) == 13

def test_points_import_list_warn(test_data):
"""Test for plantcv-annotate."""
# Read in a test image
img = cv2.imread(test_data.small_rgb_img)
# initialize interactive tool
drawer_rgb = Points(img, figsize=(12, 6), label="default")
totalpoints1 = [(158, 531), (361, 112), (500, 418), (445.50535717435065, 138.94515306122452)]
drawer_rgb.import_list(coords=totalpoints1)

assert len(drawer_rgb.coords["default"]) == 0

def test_points_import_file(test_data):
"""Test for plantcv-annotate."""
img = cv2.imread(test_data.small_rgb_img)
counter = Points(img, figsize=(8, 6))
file = test_data.pollen_coords
counter.import_file(file)

assert counter.count['total'] == 70

def test_points_view(test_data):
"""Test for plantcv-annotate."""
# Read in a test grayscale image
img = cv2.imread(test_data.small_rgb_img)

# initialize interactive tool
drawer_rgb = Points(img, figsize=(12, 6))

# simulate mouse clicks
# event 1, left click to add point
e1 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=drawer_rgb.fig.canvas,
x=0, y=0, button=1)
point1 = (200, 200)
e1.xdata, e1.ydata = point1
drawer_rgb.onclick(e1)
drawer_rgb.view(label="new", view_all=True)
e2 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=drawer_rgb.fig.canvas,
x=0, y=0, button=1)
e2.xdata, e2.ydata = (300, 200)
drawer_rgb.onclick(e2)
drawer_rgb.view(view_all=False)

assert str(drawer_rgb.fig) == "Figure(1200x600)"

def test_points_view_warn(test_data):
"""Test for plantcv-annotate."""
# Read in a test grayscale image
img = cv2.imread(test_data.small_rgb_img)

# initialize interactive tool, implied default label and "r" color
drawer_rgb = Points(img, figsize=(12, 6))

# simulate mouse clicks, event 1=left click to add point
e1 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=drawer_rgb.fig.canvas,
x=0, y=0, button=1)
point1 = (200, 200)
e1.xdata, e1.ydata = point1
drawer_rgb.onclick(e1)
drawer_rgb.view(label="new", color='r')

assert str(drawer_rgb.fig) == "Figure(1200x600)"
Loading