Skip to content

Commit

Permalink
Implement alignment service (#313)
Browse files Browse the repository at this point in the history
* move stuff around to prep for alignment service

* wip

* Don't warp perspective on small segments either

* initial class

* Add `align` method to image alignment classes

* lint

* fix FourPointTransform to handle color images

* Add tests against both transform backends

* lints

* Document new alignment function

* Reorganize module structure a bit

* add missing __init__
  • Loading branch information
jonchang authored Oct 28, 2024
1 parent a459c15 commit 8c5e20a
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 11 deletions.
1 change: 1 addition & 0 deletions OCR/ocr/services/alignment/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .image_alignment import ImageAligner as ImageAligner
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@


class FourPointTransform:
def __init__(self, image: Path):
self.image = cv.imread(str(image), cv.IMREAD_GRAYSCALE)
def __init__(self, image: Path | np.ndarray):
if isinstance(image, np.ndarray):
self.image = image
else:
self.image = cv.imread(str(image))

@classmethod
def align(self, source_image, template_image):
return FourPointTransform(source_image).dewarp()

@staticmethod
def _order_points(quadrilateral: np.ndarray) -> np.ndarray:
Expand All @@ -28,7 +35,9 @@ def _order_points(quadrilateral: np.ndarray) -> np.ndarray:

def find_largest_contour(self):
"""Compute contours for an image and find the biggest one by area."""
_, contours, _ = cv.findContours(self.image, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
contours, _ = cv.findContours(
cv.cvtColor(self.image, cv.COLOR_BGR2GRAY), cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE
)
return functools.reduce(lambda a, b: b if cv.contourArea(a) < cv.contourArea(b) else a, contours)

def simplify_polygon(self, contour):
Expand All @@ -40,8 +49,8 @@ def dewarp(self) -> np.ndarray:
biggest_contour = self.find_largest_contour()
simplified = self.simplify_polygon(biggest_contour)

height, width = self.image.shape
height, width, _ = self.image.shape
destination = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float32)

M = cv.getPerspectiveTransform(self.order_points(simplified), destination)
M = cv.getPerspectiveTransform(self._order_points(simplified), destination)
return cv.warpPerspective(self.image, M, (width, height))
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@


class ImageHomography:
def __init__(self, template: Path, match_ratio=0.3):
def __init__(self, template: Path | np.ndarray, match_ratio=0.3):
"""Initialize the image homography pipeline with a `template` image."""
if match_ratio >= 1 or match_ratio <= 0:
raise ValueError("`match_ratio` must be between 0 and 1")

self.template = cv.imread(template)
if isinstance(template, np.ndarray):
self.template = template
else:
self.template = cv.imread(template)
self.match_ratio = match_ratio
self._sift = cv.SIFT_create()

@classmethod
def align(self, source_image, template_image):
return ImageHomography(template_image).transform_homography(source_image)

def estimate_self_similarity(self):
"""Calibrate `match_ratio` using a self-similarity metric."""
raise NotImplementedError
Expand Down Expand Up @@ -48,9 +55,25 @@ def estimate_transform_matrix(self, other):
M, _ = cv.findHomography(dst_pts, src_pts, cv.RANSAC, 5.0)
return M

def transform_homography(self, other, matrix=None):
"""Run the image homography pipeline against a query image."""
def transform_homography(self, other, min_axis=100, matrix=None):
"""
Run the image homography pipeline against a query image.
Parameters:
min_axis: minimum x- and y-axis length, in pixels, to attempt to do a homography transform.
If the input image is under the axis limits, return the original input image unchanged.
matrix: if specified, a transformation matrix to warp the input image. Otherwise this will be
estimated with `estimate_transform_matrix`.
"""

if other.shape[0] < min_axis and other.shape[1] < min_axis:
return other

if matrix is None:
matrix = self.estimate_transform_matrix(other)
try:
matrix = self.estimate_transform_matrix(other)
except cv.error:
print("could not estimate transform matrix")
return other

return cv.warpPerspective(other, matrix, (self.template.shape[1], self.template.shape[0]))
19 changes: 19 additions & 0 deletions OCR/ocr/services/alignment/image_alignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np

from ocr.services.alignment.backends import ImageHomography


class ImageAligner:
def __init__(self, aligner=ImageHomography):
self.aligner = aligner

def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray:
"""
Aligns an image using the specified image alignment backend.
source_image: the image to be aligned, as a numpy ndarray.
template_image: the image that `source_image` will be aligned against, as a numpy ndarray.
May not be used for all image alignment backends.
"""
aligned_image = self.aligner.align(source_image, template_image)
return aligned_image
13 changes: 12 additions & 1 deletion OCR/tests/alignment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import cv2 as cv
import numpy as np
import pytest

from alignment import ImageHomography, RandomPerspectiveTransform
from ocr.services.alignment.backends import FourPointTransform, ImageHomography, RandomPerspectiveTransform
from ocr.services.alignment import ImageAligner


path = os.path.dirname(__file__)
Expand All @@ -14,6 +16,15 @@


class TestAlignment:
@pytest.mark.parametrize("align_class", [ImageHomography, FourPointTransform])
def test_align_implementation(self, align_class):
"""Tests that the ImageAligner class backends implement the `align` method."""
template_image = cv.imread(template_image_path)
aligner = ImageAligner(aligner=align_class)
result = aligner.align(filled_image, template_image)
assert result.shape == template_image.shape, "Aliged image has wrong shape"
assert np.median(cv.absdiff(template_image, result)) <= 1, "Median difference too high"

def test_random_warp(self):
"""Test that a random warp generates an image different from the template."""
transformed = RandomPerspectiveTransform(filled_image_path).random_transform(distortion_scale=0.1)
Expand Down

0 comments on commit 8c5e20a

Please sign in to comment.