From 8c5e20ab8aa1b474caab2033a18b7ea7d872c83f Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Mon, 28 Oct 2024 10:02:07 -0700 Subject: [PATCH] Implement alignment service (#313) * move stuff around to prep for alignment service * wip * Don't warp perspective on small segments either * initial class * Add `align` method to image alignment classes * lint * fix FourPointTransform to handle color images * Add tests against both transform backends * lints * Document new alignment function * Reorganize module structure a bit * add missing __init__ --- OCR/ocr/services/alignment/__init__.py | 1 + .../services/alignment/backends}/__init__.py | 0 .../backends}/four_point_transform.py | 19 ++++++++--- .../alignment/backends}/image_homography.py | 33 ++++++++++++++++--- .../backends}/random_perspective_transform.py | 0 OCR/ocr/services/alignment/image_alignment.py | 19 +++++++++++ OCR/tests/alignment_test.py | 13 +++++++- 7 files changed, 74 insertions(+), 11 deletions(-) create mode 100644 OCR/ocr/services/alignment/__init__.py rename OCR/{alignment => ocr/services/alignment/backends}/__init__.py (100%) rename OCR/{alignment => ocr/services/alignment/backends}/four_point_transform.py (74%) rename OCR/{alignment => ocr/services/alignment/backends}/image_homography.py (63%) rename OCR/{alignment => ocr/services/alignment/backends}/random_perspective_transform.py (100%) create mode 100644 OCR/ocr/services/alignment/image_alignment.py diff --git a/OCR/ocr/services/alignment/__init__.py b/OCR/ocr/services/alignment/__init__.py new file mode 100644 index 00000000..252e52f1 --- /dev/null +++ b/OCR/ocr/services/alignment/__init__.py @@ -0,0 +1 @@ +from .image_alignment import ImageAligner as ImageAligner diff --git a/OCR/alignment/__init__.py b/OCR/ocr/services/alignment/backends/__init__.py similarity index 100% rename from OCR/alignment/__init__.py rename to OCR/ocr/services/alignment/backends/__init__.py diff --git a/OCR/alignment/four_point_transform.py b/OCR/ocr/services/alignment/backends/four_point_transform.py similarity index 74% rename from OCR/alignment/four_point_transform.py rename to OCR/ocr/services/alignment/backends/four_point_transform.py index e73188f6..cb0954b4 100644 --- a/OCR/alignment/four_point_transform.py +++ b/OCR/ocr/services/alignment/backends/four_point_transform.py @@ -10,8 +10,15 @@ class FourPointTransform: - def __init__(self, image: Path): - self.image = cv.imread(str(image), cv.IMREAD_GRAYSCALE) + def __init__(self, image: Path | np.ndarray): + if isinstance(image, np.ndarray): + self.image = image + else: + self.image = cv.imread(str(image)) + + @classmethod + def align(self, source_image, template_image): + return FourPointTransform(source_image).dewarp() @staticmethod def _order_points(quadrilateral: np.ndarray) -> np.ndarray: @@ -28,7 +35,9 @@ def _order_points(quadrilateral: np.ndarray) -> np.ndarray: def find_largest_contour(self): """Compute contours for an image and find the biggest one by area.""" - _, contours, _ = cv.findContours(self.image, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE) + contours, _ = cv.findContours( + cv.cvtColor(self.image, cv.COLOR_BGR2GRAY), cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE + ) return functools.reduce(lambda a, b: b if cv.contourArea(a) < cv.contourArea(b) else a, contours) def simplify_polygon(self, contour): @@ -40,8 +49,8 @@ def dewarp(self) -> np.ndarray: biggest_contour = self.find_largest_contour() simplified = self.simplify_polygon(biggest_contour) - height, width = self.image.shape + height, width, _ = self.image.shape destination = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float32) - M = cv.getPerspectiveTransform(self.order_points(simplified), destination) + M = cv.getPerspectiveTransform(self._order_points(simplified), destination) return cv.warpPerspective(self.image, M, (width, height)) diff --git a/OCR/alignment/image_homography.py b/OCR/ocr/services/alignment/backends/image_homography.py similarity index 63% rename from OCR/alignment/image_homography.py rename to OCR/ocr/services/alignment/backends/image_homography.py index 2d91a5c1..5ab1ddfb 100644 --- a/OCR/alignment/image_homography.py +++ b/OCR/ocr/services/alignment/backends/image_homography.py @@ -5,15 +5,22 @@ class ImageHomography: - def __init__(self, template: Path, match_ratio=0.3): + def __init__(self, template: Path | np.ndarray, match_ratio=0.3): """Initialize the image homography pipeline with a `template` image.""" if match_ratio >= 1 or match_ratio <= 0: raise ValueError("`match_ratio` must be between 0 and 1") - self.template = cv.imread(template) + if isinstance(template, np.ndarray): + self.template = template + else: + self.template = cv.imread(template) self.match_ratio = match_ratio self._sift = cv.SIFT_create() + @classmethod + def align(self, source_image, template_image): + return ImageHomography(template_image).transform_homography(source_image) + def estimate_self_similarity(self): """Calibrate `match_ratio` using a self-similarity metric.""" raise NotImplementedError @@ -48,9 +55,25 @@ def estimate_transform_matrix(self, other): M, _ = cv.findHomography(dst_pts, src_pts, cv.RANSAC, 5.0) return M - def transform_homography(self, other, matrix=None): - """Run the image homography pipeline against a query image.""" + def transform_homography(self, other, min_axis=100, matrix=None): + """ + Run the image homography pipeline against a query image. + + Parameters: + min_axis: minimum x- and y-axis length, in pixels, to attempt to do a homography transform. + If the input image is under the axis limits, return the original input image unchanged. + matrix: if specified, a transformation matrix to warp the input image. Otherwise this will be + estimated with `estimate_transform_matrix`. + """ + + if other.shape[0] < min_axis and other.shape[1] < min_axis: + return other + if matrix is None: - matrix = self.estimate_transform_matrix(other) + try: + matrix = self.estimate_transform_matrix(other) + except cv.error: + print("could not estimate transform matrix") + return other return cv.warpPerspective(other, matrix, (self.template.shape[1], self.template.shape[0])) diff --git a/OCR/alignment/random_perspective_transform.py b/OCR/ocr/services/alignment/backends/random_perspective_transform.py similarity index 100% rename from OCR/alignment/random_perspective_transform.py rename to OCR/ocr/services/alignment/backends/random_perspective_transform.py diff --git a/OCR/ocr/services/alignment/image_alignment.py b/OCR/ocr/services/alignment/image_alignment.py new file mode 100644 index 00000000..d52472f4 --- /dev/null +++ b/OCR/ocr/services/alignment/image_alignment.py @@ -0,0 +1,19 @@ +import numpy as np + +from ocr.services.alignment.backends import ImageHomography + + +class ImageAligner: + def __init__(self, aligner=ImageHomography): + self.aligner = aligner + + def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray: + """ + Aligns an image using the specified image alignment backend. + + source_image: the image to be aligned, as a numpy ndarray. + template_image: the image that `source_image` will be aligned against, as a numpy ndarray. + May not be used for all image alignment backends. + """ + aligned_image = self.aligner.align(source_image, template_image) + return aligned_image diff --git a/OCR/tests/alignment_test.py b/OCR/tests/alignment_test.py index b1b7ea49..0ee70a0d 100644 --- a/OCR/tests/alignment_test.py +++ b/OCR/tests/alignment_test.py @@ -2,8 +2,10 @@ import cv2 as cv import numpy as np +import pytest -from alignment import ImageHomography, RandomPerspectiveTransform +from ocr.services.alignment.backends import FourPointTransform, ImageHomography, RandomPerspectiveTransform +from ocr.services.alignment import ImageAligner path = os.path.dirname(__file__) @@ -14,6 +16,15 @@ class TestAlignment: + @pytest.mark.parametrize("align_class", [ImageHomography, FourPointTransform]) + def test_align_implementation(self, align_class): + """Tests that the ImageAligner class backends implement the `align` method.""" + template_image = cv.imread(template_image_path) + aligner = ImageAligner(aligner=align_class) + result = aligner.align(filled_image, template_image) + assert result.shape == template_image.shape, "Aliged image has wrong shape" + assert np.median(cv.absdiff(template_image, result)) <= 1, "Median difference too high" + def test_random_warp(self): """Test that a random warp generates an image different from the template.""" transformed = RandomPerspectiveTransform(filled_image_path).random_transform(distortion_scale=0.1)