Skip to content

Commit

Permalink
Updated doctr latest release
Browse files Browse the repository at this point in the history
  • Loading branch information
ajkdrag committed Mar 6, 2024
1 parent 2ef5cde commit 0e2a052
Show file tree
Hide file tree
Showing 14 changed files with 441 additions and 110 deletions.
10 changes: 10 additions & 0 deletions extra-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# FORMAT
# Put your extra requirements here in the following format
#
# package[version_required]: tag1, tag2, ...

ultralytics==8.1.11: ultralytics
dill==0.3.8: ultralytics
paddleocr==2.7.0.3: paddle
paddlepaddle-gpu==2.6.0: paddle
python-doctr[torch]==0.8.1: doctr
224 changes: 137 additions & 87 deletions notebooks/experiments.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions notebooks/payee_name_extr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2396,9 +2396,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "ocrtoolkit",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "ocrtoolkit"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
4 changes: 2 additions & 2 deletions notebooks/prepping_ds_for_clf.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2049,9 +2049,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "ocrtoolkit",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "ocrtoolkit"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ matplotlib
tqdm
loguru
h5py
scikit-learn
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ exclude =
# Provide a comma-separated list of glob patterns to include for checks.
filename =
*.py

[isort]
profile = black
36 changes: 25 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,30 @@

from setuptools import find_packages, setup


def get_extra_requires(path, add_all=True):
    """Build an ``extras_require`` mapping from a tagged requirements file.

    Each non-blank, non-comment line of the file has the form
    ``requirement: tag1, tag2, ...``; the ``: tags`` part is optional and the
    last ``:`` on the line is taken as the separator. Every requirement is
    registered under each of its tags and additionally under its own name
    (the text before the first version operator).

    Args:
        path: Path of the requirements file to read.
        add_all: When true, add an ``"all"`` entry holding the union of every
            requirement found.

    Returns:
        A ``defaultdict(set)`` mapping each tag to the set of requirement
        strings registered under it.
    """
    import re
    from collections import defaultdict

    extra_deps = defaultdict(set)
    with open(path) as fp:
        for raw_line in fp:
            # Strip first so trailing newlines never leak into requirement
            # strings or tags, and so indented comments are skipped too.
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue

            tags = set()
            spec = line
            if ":" in line:
                # Split once only: a line with several colons must not raise.
                spec, tag_field = line.rsplit(":", 1)
                spec = spec.strip()
                tags.update(t.strip() for t in tag_field.split(",") if t.strip())
            if not spec:
                continue

            # The bare package name is always usable as a tag of its own.
            tags.add(re.split(r"[<=>~!]", spec)[0].strip())
            for tag in tags:
                extra_deps[tag].add(spec)

    # add tag `all` at the end
    if add_all:
        extra_deps["all"] = {dep for deps in extra_deps.values() for dep in deps}

    return extra_deps


regexp = re.compile(r".*__version__ = [\'\"](.*?)[\'\"]", re.S)

base_package = "ocrtoolkit"
Expand Down Expand Up @@ -53,17 +77,7 @@ def parse_requirements(filename):
maintainer_email="",
python_requires="==3.8.*",
install_requires=requirements,
extras_require={
"ultralytics": ["ultralytics==8.1.11"],
"paddle": ["paddleocr==2.7.0.3", "paddlepaddle-gpu==2.6.0"],
"doctr": ["python-doctr[torch]==0.7.0"],
"all": [
"ultralytics==8.1.11",
"python-doctr[torch]==0.7.0",
"paddleocr==2.7.0.3",
"paddlepaddle-gpu==2.6.0",
],
},
extras_require=get_extra_requires("extra-requirements.txt"),
keywords=["ocrtoolkit"],
package_dir={"": "src"},
packages=find_packages("src"),
Expand Down
54 changes: 54 additions & 0 deletions src/ocrtoolkit/models/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,24 @@ def load(path, device, model_kwargs, **kwargs):
)


class DOCTR_CRNN_MOBILENET_L(metaclass=BaseArch):
    """Arch entry for the ``crnn_mobilenet_v3_large`` model of the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Load the model through ``ocrtoolkit.integrations.doctr``.

        ``path``, ``device``, ``model_kwargs`` and any extra keyword arguments
        are forwarded unchanged; the task is fixed to ``"rec"`` (presumably
        text recognition -- confirm against the integration module).
        """
        # Imported lazily so the doctr integration is only needed on use.
        import ocrtoolkit.integrations.doctr as doctr_backend

        task, arch_name = "rec", "crnn_mobilenet_v3_large"
        return doctr_backend.load(
            task, arch_name, path, device, model_kwargs, **kwargs
        )


class DOCTR_PARSEQ(metaclass=BaseArch):
    """Arch entry for the ``parseq`` model of the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Load the model through ``ocrtoolkit.integrations.doctr``.

        ``path``, ``device``, ``model_kwargs`` and any extra keyword arguments
        are forwarded unchanged; the task is fixed to ``"rec"`` (presumably
        text recognition -- confirm against the integration module).
        """
        # Imported lazily so the doctr integration is only needed on use.
        import ocrtoolkit.integrations.doctr as doctr_backend

        task, arch_name = "rec", "parseq"
        return doctr_backend.load(
            task, arch_name, path, device, model_kwargs, **kwargs
        )


class DOCTR_DB_RESNET50(metaclass=BaseArch):
@staticmethod
def load(path, device, model_kwargs, **kwargs):
Expand All @@ -45,6 +63,42 @@ def load(path, device, model_kwargs, **kwargs):
)


class DOCTR_DB_RESNET34(metaclass=BaseArch):
    """Arch entry for the ``db_resnet34`` model of the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Load the model through ``ocrtoolkit.integrations.doctr``.

        ``path``, ``device``, ``model_kwargs`` and any extra keyword arguments
        are forwarded unchanged; the task is fixed to ``"det"`` (presumably
        text detection -- confirm against the integration module).
        """
        # Imported lazily so the doctr integration is only needed on use.
        import ocrtoolkit.integrations.doctr as doctr_backend

        task, arch_name = "det", "db_resnet34"
        return doctr_backend.load(
            task, arch_name, path, device, model_kwargs, **kwargs
        )


class DOCTR_DB_MOBILENET_L(metaclass=BaseArch):
    """Arch entry for the ``db_mobilenet_v3_large`` model of the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Load the model through ``ocrtoolkit.integrations.doctr``.

        ``path``, ``device``, ``model_kwargs`` and any extra keyword arguments
        are forwarded unchanged; the task is fixed to ``"det"`` (presumably
        text detection -- confirm against the integration module).
        """
        # Imported lazily so the doctr integration is only needed on use.
        import ocrtoolkit.integrations.doctr as doctr_backend

        task, arch_name = "det", "db_mobilenet_v3_large"
        return doctr_backend.load(
            task, arch_name, path, device, model_kwargs, **kwargs
        )


class DOCTR_FAST_TINY(metaclass=BaseArch):
    """Arch entry for the ``fast_tiny`` model of the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Load the model through ``ocrtoolkit.integrations.doctr``.

        ``path``, ``device``, ``model_kwargs`` and any extra keyword arguments
        are forwarded unchanged; the task is fixed to ``"det"`` (presumably
        text detection -- confirm against the integration module).
        """
        # Imported lazily so the doctr integration is only needed on use.
        import ocrtoolkit.integrations.doctr as doctr_backend

        task, arch_name = "det", "fast_tiny"
        return doctr_backend.load(
            task, arch_name, path, device, model_kwargs, **kwargs
        )


class DOCTR_FAST_SMALL(metaclass=BaseArch):
    """Arch entry for the ``fast_small`` model of the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Load the model through ``ocrtoolkit.integrations.doctr``.

        ``path``, ``device``, ``model_kwargs`` and any extra keyword arguments
        are forwarded unchanged; the task is fixed to ``"det"`` (presumably
        text detection -- confirm against the integration module).
        """
        # Imported lazily so the doctr integration is only needed on use.
        import ocrtoolkit.integrations.doctr as doctr_backend

        task, arch_name = "det", "fast_small"
        return doctr_backend.load(
            task, arch_name, path, device, model_kwargs, **kwargs
        )


class GCV_OCR(metaclass=BaseArch):
"""Google Cloud Vision OCR
Here `path` arg points to service account json file
Expand Down
3 changes: 3 additions & 0 deletions src/ocrtoolkit/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
from .io_utils import *
from .misc_utils import *
from .network_utils import *
from .geometry_utils import *
from .det_utils import *
from .model_utils import *
63 changes: 63 additions & 0 deletions src/ocrtoolkit/utilities/det_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
from typing import List, Tuple
from ocrtoolkit.utilities.geometry_utils import (
estimate_page_angle,
rotate_boxes,
)


def sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute a top-to-bottom, left-to-right ordering of bounding boxes.

    Args:
        boxes: Either a 2-D array whose columns 0-3 are read as
            (xmin, ymin, xmax, ymax), or a 3-D array of polygons. 3-D input
            is first rotated by the negated estimated page angle and reduced
            to its axis-aligned (min, max) envelope.

    Returns:
        A tuple ``(order, boxes)``: ``order`` is the argsort of the reading
        key and ``boxes`` is the 2-D array the key was computed on (the
        input itself when it was already 2-D).
    """
    if boxes.ndim == 3:  # Rotated boxes
        page_angle = estimate_page_angle(boxes)
        straightened = rotate_boxes(
            loc_preds=boxes,
            angle=-page_angle,
            orig_shape=(1024, 1024),
            min_angle=5.0,
        )
        boxes = np.concatenate(
            (straightened.min(axis=1), straightened.max(axis=1)), axis=-1
        )

    # Key = xmin + 2 * ymax / median box height: the vertical term dominates,
    # so boxes are ranked by row first and by x position within a row.
    median_height = np.median(boxes[:, 3] - boxes[:, 1])
    reading_key = boxes[:, 0] + 2 * boxes[:, 3] / median_height
    return reading_key.argsort(), boxes


def resolve_sub_lines(
    boxes: np.ndarray, word_idcs: List[int], paragraph_break: float
) -> List[List[int]]:
    """Split one line of words into sub-lines at large horizontal gaps.

    Args:
        boxes: 2-D array; column 0 is used as each box's left edge and
            column 2 as its right edge.
        word_idcs: Row indices into ``boxes`` of the words on the line.
        paragraph_break: A gap (next left edge minus previous right edge)
            at or above this value starts a new sub-line.

    Returns:
        The sub-lines, each a list of indices ordered by left edge. A line
        with fewer than two words is returned as a single sub-line.
    """
    ordered = sorted(word_idcs, key=lambda j: boxes[j, 0])
    if len(ordered) <= 1:
        return [ordered]

    groups = [[ordered[0]]]
    for current in ordered[1:]:
        previous = groups[-1][-1]
        gap = boxes[current, 0] - boxes[previous, 2]
        if gap < paragraph_break:
            groups[-1].append(current)
        else:
            groups.append([current])
    return groups


def resolve_lines(boxes: np.ndarray, paragraph_break: float) -> List[List[int]]:
    """Group boxes into reading-order lines.

    Boxes are ordered with ``sort_boxes`` and then walked once: a box joins
    the current line while its vertical centre stays within half the median
    box height of the line's running mean centre, otherwise a new line
    starts. Each finished line is split horizontally with
    ``resolve_sub_lines``.

    Args:
        boxes: Boxes in any format accepted by ``sort_boxes``.
        paragraph_break: Horizontal gap threshold forwarded to
            ``resolve_sub_lines``.

    Returns:
        A list of lines, each a list of box indices. An empty ``boxes``
        array yields an empty list.
    """
    # Without this guard an empty input fails on `idxs[0]` below.
    if len(boxes) == 0:
        return []

    idxs, boxes = sort_boxes(boxes)
    y_med = np.median(boxes[:, 3] - boxes[:, 1])

    lines = []
    words = [idxs[0]]  # the first box in reading order opens the first line
    y_center_sum = boxes[idxs[0], [1, 3]].mean()
    for idx in idxs[1:]:
        y_center = boxes[idx, [1, 3]].mean()
        # Compare against the mean centre of the words gathered so far.
        if abs(y_center - y_center_sum / len(words)) < y_med / 2:
            words.append(idx)
            y_center_sum += y_center
        else:
            lines.extend(resolve_sub_lines(boxes, words, paragraph_break))
            words, y_center_sum = [idx], y_center

    if words:  # Process the last line
        lines.extend(resolve_sub_lines(boxes, words, paragraph_break))

    return lines
2 changes: 1 addition & 1 deletion src/ocrtoolkit/utilities/draw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from PIL import Image, ImageDraw, ImageFont

FONT_PATH = Path(__file__).parent.parent / "assets/Ubuntu-R.ttf"
FONT_PATH = Path(__file__).parent.parent.joinpath("assets/Ubuntu-R.ttf")


def draw_bbox(
Expand Down
135 changes: 135 additions & 0 deletions src/ocrtoolkit/utilities/geometry_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import numpy as np
from typing import Optional, Tuple


def estimate_page_angle(polys: np.ndarray) -> float:
    """Estimate the page rotation from a batch of oriented polygons.

    Args:
        polys: Array indexed as ``polys[n, corner, xy]``; corners 0 and 3
            are used as the left side and corners 1 and 2 as the right side
            of each polygon.

    Returns:
        The median over polygons of ``arctan((yleft - yright) / (xright - xleft))``
        converted to degrees, or ``0.0`` when the computation raises a
        floating-point error (e.g. a zero horizontal extent).
    """
    # Summed (not averaged) side points: the factor 2 cancels in the ratio.
    left_side = polys[:, 0] + polys[:, 3]
    right_side = polys[:, 1] + polys[:, 2]
    x_left, y_left = left_side[:, 0], left_side[:, 1]
    x_right, y_right = right_side[:, 0], right_side[:, 1]

    # Promote division-by-zero / invalid-value warnings to exceptions so a
    # degenerate polygon yields the 0.0 fallback instead of inf/nan.
    with np.errstate(divide="raise", invalid="raise"):
        try:
            # Y axis from top to bottom, hence (left - right) in the numerator.
            slopes = (y_left - y_right) / (x_right - x_left)
            angles_deg = np.arctan(slopes) * 180 / np.pi
            return float(np.median(angles_deg))
        except FloatingPointError:
            return 0.0


def remap_boxes(
    loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape: Tuple[int, int]
) -> np.ndarray:
    """Re-express relative polygon coordinates for a differently sized image.

    Each coordinate is scaled to absolute units of the origin image, shifted
    by half the size difference between destination and origin (i.e. the
    origin image is treated as centred in the destination), and divided by
    the destination size.

    Args:
    ----
        loc_preds: array indexed as ``[n, point, xy]`` of RELATIVE coordinates
        orig_shape: (height, width) of the origin image
        dest_shape: (height, width) of the destination image

    Returns:
    -------
        A copy of ``loc_preds`` with coordinates relative to the destination.

    Raises:
    ------
        ValueError: if ``dest_shape`` or ``orig_shape`` does not have length 2
    """
    if len(dest_shape) != 2:
        raise ValueError(f"Mask length should be 2, was found at: {len(dest_shape)}")
    if len(orig_shape) != 2:
        raise ValueError(
            f"Image_shape length should be 2, was found at: {len(orig_shape)}"
        )

    (src_h, src_w), (dst_h, dst_w) = orig_shape, dest_shape
    remapped = loc_preds.copy()
    # Channel 0 holds x (paired with widths), channel 1 holds y (heights).
    for channel, (src, dst) in enumerate(((src_w, dst_w), (src_h, dst_h))):
        remapped[:, :, channel] = (
            (loc_preds[:, :, channel] * src) + (dst - src) / 2
        ) / dst
    return remapped


def rotate_boxes(
    loc_preds: np.ndarray,
    angle: float,
    orig_shape: Tuple[int, int],
    min_angle: float = 1.0,
    target_shape: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
    """Rotate relative boxes around the centre of the page.

    2-D input is first expanded to four corner points per box (columns 0-3
    read as xmin, ymin, xmax, ymax). The rotation is skipped when ``angle``
    is within ``min_angle`` of 0 or of +/-90 degrees; in that case the
    4-point boxes are returned untouched.

    Args:
    ----
        loc_preds: 2-D array of straight boxes or (N, 4, 2) array of
            polygons, in RELATIVE coordinates
        angle: rotation angle in degrees
        orig_shape: (height, width) of the origin image
        min_angle: angular tolerance below which no rotation is applied
        target_shape: if given, the rotated boxes are passed through
            ``remap_boxes`` with this destination shape

    Returns:
    -------
        An (N, 4, 2) array of relative corner points.
    """
    # Work on a copy and normalise straight boxes to 4-point polygons.
    corners = loc_preds.copy()
    if corners.ndim == 2:
        corner_columns = ([0, 1], [2, 1], [2, 3], [0, 3])
        corners = np.stack([corners[:, cols] for cols in corner_columns], axis=1)

    # Near-zero and near-right angles are treated as "no rotation".
    magnitude = abs(angle)
    if magnitude < min_angle or magnitude > 90 - min_angle:
        return corners

    height, width = orig_shape[0], orig_shape[1]

    theta = angle * np.pi / 180.0  # radians for the numpy trig functions
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=corners.dtype)

    # Rotate in absolute pixel coordinates about the image centre.
    absolute: np.ndarray = np.stack(
        (corners[:, :, 0] * width, corners[:, :, 1] * height), axis=-1
    )
    center = (width / 2, height / 2)
    turned = center + np.matmul(absolute - center, rotation)

    # Back to relative coordinates.
    rotated: np.ndarray = np.stack(
        (turned[:, :, 0] / width, turned[:, :, 1] / height), axis=-1
    )

    # Apply a mask if requested
    if target_shape is not None:
        rotated = remap_boxes(rotated, orig_shape=orig_shape, dest_shape=target_shape)

    return rotated
7 changes: 3 additions & 4 deletions src/ocrtoolkit/utilities/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import torch
def load_state_dict(path, model, ignore_keys: list = None):
import torch


def load_state_dict(path: str, model: torch.nn.Module, ignore_keys: list = None):
state_dict = torch.load(archive_path, map_location="cpu")
state_dict = torch.load(path, map_location="cpu")
if ignore_keys is not None and len(ignore_keys) > 0:
for key in ignore_keys:
state_dict.pop(key)
Expand Down
Loading

0 comments on commit 0e2a052

Please sign in to comment.