Skip to content

Commit

Permalink
refactor looping method in correct_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
HaleySchuhl committed Sep 18, 2024
1 parent 08e14bc commit e757806
Showing 1 changed file with 110 additions and 105 deletions.
215 changes: 110 additions & 105 deletions plantcv/annotate/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ def correct_mask(self, mask):
final_mask = np.zeros(np.shape(mask), np.uint32)
debug_img = np.zeros(np.shape(mask), np.uint8)
debug_img_duplicates = debug_img.copy()
pts_all = sum(self.coords.values(), [])
labels_all = []
for pt in pts_all:
coord_class_label = [k for k, v in self.coords.items() if pt in v]
labels_all.append(coord_class_label)

bin_mask = np.where(mask > 0, 255, 0).astype(np.uint8)
pts_mask = self._create_pts_mask(bin_mask, labelnames)
Expand All @@ -228,121 +233,121 @@ def correct_mask(self, mask):
# Initialize object count
object_id_count = 1
# pts in class used for recovering and labeling
for names in labelnames:
for (x, y) in self.coords[names]:
x = int(x)
y = int(y)
mask_pixel_value = labeled_mask_all[y, x]
# Check if current annotation can be resolved to an object in the mask
if mask_pixel_value == 0:
warn(f"Object could not be resolved at coordinate: x = {x}, y = {y}")
unrecovered_ids.append(object_id_count)
added_obj_labels.append(object_id_count)
for p, current_pt in enumerate(pts_all):
x = int(current_pt[1])
y = int(current_pt[0])
names = labels_all[p]
mask_pixel_value = labeled_mask_all[y, x]
# Check if current annotation can be resolved to an object in the mask
if mask_pixel_value == 0:
warn(f"Object could not be resolved at coordinate: x = {x}, y = {y}")
unrecovered_ids.append(object_id_count)
added_obj_labels.append(object_id_count)
analysis_labels.append(names)
# Add info to label object IDs in debug img
debug_labels, debug_coords = _add_debug_id(debug_labels, debug_coords,
object_id_count, (x, y))
# Add the unresolved object to the labeled mask and the debug img
debug_img, final_mask, object_id_count = _draw_unresolved_object(debug_img,
final_mask,
obj_number=object_id_count,
coord=(x, y))
if mask_pixel_value > 0:
# An object is resolved but check if there are other annotations associated with an object
mask_pixel_index = np.where(keep_pixel_vals == mask_pixel_value)[0]
associated_count = keep_object_count[mask_pixel_index]
if associated_count == 1:
# New object getting added
added_obj_labels.append(mask_pixel_value)
analysis_labels.append(names)
# Add info to label object IDs in debug img
debug_labels, debug_coords = _add_debug_id(debug_labels, debug_coords,
object_id_count, (x, y))
# Add the unresolved object to the labeled mask and the debug img
debug_img, final_mask, object_id_count = _draw_unresolved_object(debug_img,
final_mask,
obj_number=object_id_count,
coord=(x, y))
if mask_pixel_value > 0:
# An object is resolved but check if there are other annotations associated with an object
mask_pixel_index = np.where(keep_pixel_vals == mask_pixel_value)[0]
associated_count = keep_object_count[mask_pixel_index]
if associated_count == 1:
# New object getting added
added_obj_labels.append(mask_pixel_value)
analysis_labels.append(names)
debug_labels, debug_coords = _add_debug_id(debug_labels, debug_coords,
object_id_count, (x, y))
# Draw on labeled mask and debug img
debug_img, final_mask, object_id_count = _draw_resolved(debug_img, final_mask, labeled_mask_all,
mask_pixel_value, object_id_count)
if associated_count > 1:
# Has this object been handled already?
if mask_pixel_value not in added_obj_labels:
# Object annotated more than once so find all associated annotations
associated_coords = np.where(masked_image2 == mask_pixel_value)
associated_coords = tuple(zip(*associated_coords))
first_coord = (associated_coords[0][1], associated_coords[0][0])
coord_labels = []
# Find all class labels for each annotation
object_id_count, (x, y))
# Draw on labeled mask and debug img
debug_img, final_mask, object_id_count = _draw_resolved(debug_img, final_mask, labeled_mask_all,
mask_pixel_value, object_id_count)
if associated_count > 1:
# Has this object been handled already?
if mask_pixel_value not in added_obj_labels:
# Object annotated more than once so find all associated annotations
associated_coords = np.where(masked_image2 == mask_pixel_value)
associated_coords = tuple(zip(*associated_coords))
first_coord = (associated_coords[0][1], associated_coords[0][0])
coord_labels = []
# Find all class labels for each annotation
for dup_coord in associated_coords:
# Flip x & y for numpy, and find the associated class label with each coordinate
coord_class_label = [k for k, v in self.coords.items() if (dup_coord[1], dup_coord[0]) in v]
coord_labels.append(coord_class_label)
# Is there more than one class label associated with the given object?
re = np.unique(coord_labels)
if len(re) == 1:
# Labels are duplicated e.g. "total", "total"
# Draw the ghost of objects removed
debug_img_duplicates = np.where(labeled_mask_all == mask_pixel_value,
(255), debug_img_duplicates)
# Fill in the duplicate object in the labeled mask, replace with pixel annotations
final_mask = np.where(labeled_mask_all == mask_pixel_value, (0), final_mask)
added_obj_labels.append(mask_pixel_value)
for dup_coord in associated_coords:
# Flip x & y for numpy, and find the associated class label with each coordinate
coord_class_label = [k for k, v in self.coords.items() if (dup_coord[1], dup_coord[0]) in v]
coord_labels.append(coord_class_label)
# Is there more than one class label associated with the given object?
re = np.unique(coord_labels)
if len(re) == 1:
# Labels are duplicated e.g. "total", "total"
# Draw each pixel in the final mask
final_mask[dup_coord] = object_id_count
analysis_labels.append(names)
# Add a thicker pixel where unresolved annotation to the debug img
cv2.circle(debug_img, (dup_coord[1], dup_coord[0]), radius=params.line_thickness,
color=(object_id_count), thickness=-1)
# Add debug label annotations later
debug_labels, debug_coords = _add_debug_id(debug_labels, debug_coords,
object_id_count, (dup_coord[1], dup_coord[0]))
# Increment object count up so each pixel drawn in labeled mask is unique
object_id_count += 1
if len(re) > 1:
# More than one class label associated with a given object
splitup = []
# Split on "_" in case something has already been combined
for lbls in coord_labels:
list_lbl = []
for lbl in lbls:
list_lbl.append(lbl.split("_"))
splitup.append(np.concatenate(list_lbl))
# Flatten list of labels
flat = np.concatenate(splitup)
# Grab each unique label from the list
unique_lbls, lbl_counts = np.unique(flat, return_counts=True)
# Is there duplication within each class label for the given object?
if np.all(lbl_counts == 1):
# If no, Concat with "_" delimiter
concat_lbl = "_".join(list(unique_lbls))
warn(f"labels getting concatenated to '{concat_lbl}' at {first_coord}")
# Adding the object
added_obj_labels.append(mask_pixel_value)
analysis_labels.append(concat_lbl)
# Add debug label annotations later
debug_labels, debug_coords = _add_debug_id(debug_labels, debug_coords,
object_id_count, first_coord)
# Draw on labeled mask and debug img
debug_img, final_mask, object_id_count = _draw_resolved(
debug_img, final_mask, labeled_mask_all, mask_pixel_value, object_id_count)
else:
# e.g. "total", "total", "germinated" is too complex to measure
warn(f"The object at {first_coord} was removed for being too complex. "
"It was associated with the following labels: {flat1}")
added_obj_labels.append(mask_pixel_value)
# Draw the ghost of objects removed
debug_img_duplicates = np.where(labeled_mask_all == mask_pixel_value,
(255), debug_img_duplicates)
# Fill in the duplicate object in the labeled mask, replace with pixel annotations
# Fill in the duplicate object in the labeled mask
final_mask = np.where(labeled_mask_all == mask_pixel_value, (0), final_mask)
added_obj_labels.append(mask_pixel_value)
for dup_coord in associated_coords:
# Draw each pixel in the final mask
# ADD PIXEL ANNOTATIONS TO FINAL MASK AND TO DEBUG ?
for i, dup_coord in enumerate(associated_coords):
final_mask[dup_coord] = object_id_count
analysis_labels.append(names)
# Add a thicker pixel where unresolved annotation to the debug img
cv2.circle(debug_img, (dup_coord[1], dup_coord[0]), radius=params.line_thickness,
color=(object_id_count), thickness=-1)
# Add debug label annotations later
analysis_labels.append(coord_labels[i])
cv2.circle(debug_img, (dup_coord[1], dup_coord[0]),
radius=params.line_thickness, color=(object_id_count),
thickness=-1)
debug_labels, debug_coords = _add_debug_id(debug_labels, debug_coords,
object_id_count, (dup_coord[1], dup_coord[0]))
# Increment object count up so each pixel drawn in labeled mask is unique
object_id_count,
(dup_coord[1], dup_coord[0]))
object_id_count += 1
if len(re) > 1:
# More than one class label associated with a given object
splitup = []
# Split on "_" in case something has already been combined
for lbls in coord_labels:
list_lbl = []
for lbl in lbls:
list_lbl.append(lbl.split("_"))
splitup.append(np.concatenate(list_lbl))
# Flatten list of labels
flat = np.concatenate(splitup)
# Grab each unique label from the list
unique_lbls, lbl_counts = np.unique(flat, return_counts=True)
# Is there duplication within each class label for the given object?
if np.all(lbl_counts == 1):
# If no, Concat with "_" delimiter
concat_lbl = "_".join(list(unique_lbls))
warn(f"labels getting concatenated to '{concat_lbl}' at {first_coord}")
# Adding the object
added_obj_labels.append(mask_pixel_value)
analysis_labels.append(concat_lbl)
# Add debug label annotations later
debug_labels, debug_coords = _add_debug_id(debug_labels, debug_coords,
object_id_count, first_coord)
# Draw on labeled mask and debug img
debug_img, final_mask, object_id_count = _draw_resolved(
debug_img, final_mask, labeled_mask_all, mask_pixel_value, object_id_count)
else:
# e.g. "total", "total", "germinated" is too complex to measure
warn(f"The object at {first_coord} was removed for being too complex. "
"It was associated with the following labels: {flat1}")
added_obj_labels.append(mask_pixel_value)
# Draw the ghost of objects removed
debug_img_duplicates = np.where(labeled_mask_all == mask_pixel_value,
(255), debug_img_duplicates)
# Fill in the duplicate object in the labeled mask
final_mask = np.where(labeled_mask_all == mask_pixel_value, (0), final_mask)
# ADD PIXEL ANNOTATIONS TO FINAL MASK AND TO DEBUG ?
for i, dup_coord in enumerate(associated_coords):
final_mask[dup_coord] = object_id_count
analysis_labels.append(coord_labels[i])
cv2.circle(debug_img, (dup_coord[1], dup_coord[0]),
radius=params.line_thickness, color=(object_id_count),
thickness=-1)
debug_labels, debug_coords = _add_debug_id(debug_labels, debug_coords,
object_id_count,
(dup_coord[1], dup_coord[0]))
object_id_count += 1
# Combine and colorize components of the debug image
debug_img_duplicates_rgb = _draw_ghost_of_duplicates_removed(debug_img_duplicates)
debug_img = colorize_label_img(debug_img)
Expand Down

0 comments on commit e757806

Please sign in to comment.