Skip to content

Commit

Permalink
Merge pull request #829 from danforthcenter/update_color_correction
Browse files Browse the repository at this point in the history
Update color correction
  • Loading branch information
nfahlgren authored Nov 4, 2021
2 parents 4d69e78 + 4680a34 commit dee77c9
Show file tree
Hide file tree
Showing 10 changed files with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions plantcv/plantcv/transform/color_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def get_color_matrix(rgb_img, mask):
if len(np.shape(mask)) != 2:
fatal_error("Input mask is not an gray-scale image.")

# convert to float and normalize to work with values between 0-1
rgb_img = rgb_img.astype(np.float64)/255

# create empty color_matrix
color_matrix = np.zeros((len(np.unique(mask))-1, 4))

Expand Down Expand Up @@ -205,6 +208,8 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
# split transformation_matrix
red, green, blue, red2, green2, blue2, red3, green3, blue3 = np.split(transformation_matrix, 9, 1)

# convert img to float to avoid integer overflow, normalize between 0-1
source_img = source_img.astype(np.float64)/255
# find linear, square, and cubic values of source_img color channels
source_b, source_g, source_r = cv2.split(source_img)
source_b2 = np.square(source_b)
Expand All @@ -226,9 +231,10 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
bgr = [b, g, r]
corrected_img = cv2.merge(bgr)

# round corrected_img elements to be within range and of the correct data type
corrected_img = np.rint(corrected_img)
corrected_img[np.where(corrected_img > 255)] = 255
# return values of the image to the 0-255 range
corrected_img = 255*np.clip(corrected_img, 0, 1)
corrected_img = np.floor(corrected_img)
# cast back to unsigned int
corrected_img = corrected_img.astype(np.uint8)

if params.debug == "print":
Expand All @@ -237,6 +243,8 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
elif params.debug == "plot":
# If debug is plot, print a horizontal view of source_img, corrected_img, and target_img to the plotting device
# plot horizontal comparison of source_img, corrected_img (with rounded elements) and target_img
# cast source_img back to unsigned int between 0-255 for visualization
source_img = (255*source_img).astype(np.uint8)
plot_image(np.hstack([source_img, corrected_img, target_img]))

# return corrected_img
Expand Down
Binary file modified tests/data/matrix_b1.npz
Binary file not shown.
Binary file modified tests/data/matrix_b2.npz
Binary file not shown.
Binary file modified tests/data/matrix_m1.npz
Binary file not shown.
Binary file modified tests/data/matrix_m2.npz
Binary file not shown.
Binary file modified tests/data/source1_matrix.npz
Binary file not shown.
Binary file modified tests/data/source2_matrix.npz
Binary file not shown.
Binary file modified tests/data/source_corrected.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/data/target_matrix.npz
Binary file not shown.
Binary file modified tests/data/transformation_matrix1.npz
Binary file not shown.

0 comments on commit dee77c9

Please sign in to comment.