Skip to content

Commit

Permalink
dev(narugo): add unittest and better docs
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Oct 16, 2023
1 parent 2d94ec3 commit f79a2bf
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 9 deletions.
21 changes: 21 additions & 0 deletions imgutils/ocr/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,22 @@
"""
Overview:
Detect and recognize text in images.
The models are exported from `PaddleOCR <https://github.com/PaddlePaddle/PaddleOCR>`_, hosted on
`huggingface - deepghs/paddleocr <https://huggingface.co/deepghs/paddleocr/tree/main>`_.
.. image:: ocr_demo.plot.py.svg
:align: center
This is an overall benchmark of all the text detection models:
.. image:: ocr_det_benchmark.plot.py.svg
:align: center
and an overall benchmark of all the available text recognition models:
.. image:: ocr_rec_benchmark.plot.py.svg
:align: center
"""
from .entry import detect_text_with_ocr, ocr, list_det_models, list_rec_models
3 changes: 0 additions & 3 deletions imgutils/ocr/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ def _open_ocr_detection_model(model):
f'det/{model}/model.onnx',
))

print(ort.get_inputs()[0].shape)
return ort


def _box_score_fast(bitmap, _box):
h, w = bitmap.shape[:2]
Expand Down
60 changes: 56 additions & 4 deletions imgutils/ocr/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,53 @@
from .detect import _detect_text, _list_det_models
from .recognize import _text_recognize, _list_rec_models
from ..data import ImageTyping, load_image
from ..utils import tqdm

_DEFAULT_DET_MODEL = 'ch_PP-OCRv4_det'
_DEFAULT_REC_MODEL = 'ch_PP-OCRv4_rec'


def list_det_models() -> List[str]:
    """
    Enumerate the text detection models that can be used for OCR.

    This is the public entry point; the lookup itself is delegated to
    ``_list_det_models`` from the ``detect`` module.

    :return: Names of the available text detection models.
    :rtype: List[str]
    """
    det_model_names = _list_det_models()
    return det_model_names


def list_rec_models() -> List[str]:
    """
    Enumerate the text recognition models that can be used for OCR.

    This is the public entry point; the lookup itself is delegated to
    ``_list_rec_models`` from the ``recognize`` module.

    :return: Names of the available text recognition models.
    :rtype: List[str]
    """
    rec_model_names = _list_rec_models()
    return rec_model_names


def detect_text_with_ocr(image: ImageTyping, model: str = _DEFAULT_DET_MODEL,
heat_threshold: float = 0.3, box_threshold: float = 0.7,
max_candidates: int = 1000, unclip_ratio: float = 2.0) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Detect text in an image using an OCR model.
:param image: The input image.
:type image: ImageTyping
:param model: The name of the text detection model.
:type model: str, optional
:param heat_threshold: The heat map threshold for text detection.
:type heat_threshold: float, optional
:param box_threshold: The box threshold for text detection.
:type box_threshold: float, optional
:param max_candidates: The maximum number of candidates to consider.
:type max_candidates: int, optional
:param unclip_ratio: The unclip ratio for text detection.
:type unclip_ratio: float, optional
:return: A list of detected text boxes, their corresponding text content, and their confidence scores.
:rtype: List[Tuple[Tuple[int, int, int, int], str, float]]
"""
retval = []
for box, _, score in _detect_text(image, model, heat_threshold, box_threshold, max_candidates, unclip_ratio):
retval.append((box, 'text', score))
Expand All @@ -31,12 +60,35 @@ def detect_text_with_ocr(image: ImageTyping, model: str = _DEFAULT_DET_MODEL,
def ocr(image: ImageTyping, detect_model: str = _DEFAULT_DET_MODEL,
recognize_model: str = _DEFAULT_REC_MODEL, heat_threshold: float = 0.3, box_threshold: float = 0.7,
max_candidates: int = 1000, unclip_ratio: float = 2.0, rotation_threshold: float = 1.5,
is_remove_duplicate: bool = False, silent: bool = False):
is_remove_duplicate: bool = False):
"""
Perform optical character recognition (OCR) on an image.
:param image: The input image.
:type image: ImageTyping
:param detect_model: The name of the text detection model.
:type detect_model: str, optional
:param recognize_model: The name of the text recognition model.
:type recognize_model: str, optional
:param heat_threshold: The heat map threshold for text detection.
:type heat_threshold: float, optional
:param box_threshold: The box threshold for text detection.
:type box_threshold: float, optional
:param max_candidates: The maximum number of candidates to consider.
:type max_candidates: int, optional
:param unclip_ratio: The unclip ratio for text detection.
:type unclip_ratio: float, optional
:param rotation_threshold: The rotation threshold for text detection.
:type rotation_threshold: float, optional
:param is_remove_duplicate: Whether to remove duplicate text content.
:type is_remove_duplicate: bool, optional
:return: A list of detected text boxes, their corresponding text content, and their combined confidence scores.
:rtype: List[Tuple[Tuple[int, int, int, int], str, float]]
"""
image = load_image(image)
retval = []
for (x0, y0, x1, y1), _, score in \
tqdm(_detect_text(image, detect_model, heat_threshold,
box_threshold, max_candidates, unclip_ratio), silent=silent):
_detect_text(image, detect_model, heat_threshold, box_threshold, max_candidates, unclip_ratio):
width, height = x1 - x0, y1 - y0
area = image.crop((x0, y0, x1, y1))
if height >= width * rotation_threshold:
Expand Down
4 changes: 2 additions & 2 deletions imgutils/ocr/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _open_ocr_recognition_dictionary(model) -> List[str]:
return ['<blank>', *dict_, ' ']


def decode(text_index, model: str, text_prob=None, is_remove_duplicate=False):
def _text_decode(text_index, model: str, text_prob=None, is_remove_duplicate=False):
retval = []
ignored_tokens = [0]
batch_size = len(text_index)
Expand Down Expand Up @@ -76,7 +76,7 @@ def _text_recognize(image: ImageTyping, model: str = 'ch_PP-OCRv4_rec',

indices = output.argmax(axis=2)
confs = output.max(axis=2)
return decode(indices, model, confs, is_remove_duplicate)[0]
return _text_decode(indices, model, confs, is_remove_duplicate)[0]


@lru_cache()
Expand Down
Empty file added test/ocr/__init__.py
Empty file.
117 changes: 117 additions & 0 deletions test/ocr/test_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import pytest
from PIL import Image

from imgutils.ocr import detect_text_with_ocr, list_det_models, list_rec_models, ocr
from test.testings import get_testfile


@pytest.fixture()
def ocr_img_plot():
    """Provide the ``ocr/plot.png`` test file as resolved by ``get_testfile``."""
    plot_file = get_testfile('ocr', 'plot.png')
    yield plot_file


@pytest.fixture()
def ocr_img_plot_pil(ocr_img_plot):
    """
    Provide the ``ocr/plot.png`` test file opened as a PIL image.

    The image is opened in a ``with`` block so that the underlying file
    handle is released once the requesting test has finished, instead of
    being left open until garbage collection.

    :param ocr_img_plot: Value of the ``ocr_img_plot`` fixture, passed to ``Image.open``.
    """
    with Image.open(ocr_img_plot) as image:
        yield image


@pytest.fixture()
def ocr_img_comic():
    """Provide the ``ocr/comic.jpg`` test file as resolved by ``get_testfile``."""
    comic_file = get_testfile('ocr', 'comic.jpg')
    yield comic_file


@pytest.fixture()
def ocr_img_comic_pil(ocr_img_comic):
    """
    Provide the ``ocr/comic.jpg`` test file opened as a PIL image.

    The image is opened in a ``with`` block so that the underlying file
    handle is released once the requesting test has finished, instead of
    being left open until garbage collection.

    :param ocr_img_comic: Value of the ``ocr_img_comic`` fixture, passed to ``Image.open``.
    """
    with Image.open(ocr_img_comic) as image:
        yield image


@pytest.fixture()
def ocr_img_anime_subtitle():
    """Provide the ``ocr/anime_subtitle.jpg`` test file as resolved by ``get_testfile``."""
    subtitle_file = get_testfile('ocr', 'anime_subtitle.jpg')
    yield subtitle_file


@pytest.fixture()
def ocr_img_anime_subtitle_pil(ocr_img_anime_subtitle):
    """
    Provide the ``ocr/anime_subtitle.jpg`` test file opened as a PIL image.

    The image is opened in a ``with`` block so that the underlying file
    handle is released once the requesting test has finished, instead of
    being left open until garbage collection.

    :param ocr_img_anime_subtitle: Value of the ``ocr_img_anime_subtitle`` fixture,
        passed to ``Image.open``.
    """
    with Image.open(ocr_img_anime_subtitle) as image:
        yield image


@pytest.fixture()
def ocr_img_post_text():
    """Provide the ``ocr/post_text.jpg`` test file as resolved by ``get_testfile``."""
    post_text_file = get_testfile('ocr', 'post_text.jpg')
    yield post_text_file


@pytest.fixture()
def ocr_img_post_text_pil(ocr_img_post_text):
    """
    Provide the ``ocr/post_text.jpg`` test file opened as a PIL image.

    The image is opened in a ``with`` block so that the underlying file
    handle is released once the requesting test has finished, instead of
    being left open until garbage collection.

    :param ocr_img_post_text: Value of the ``ocr_img_post_text`` fixture, passed to ``Image.open``.
    """
    with Image.open(ocr_img_post_text) as image:
        yield image


@pytest.mark.unittest
class TestOcr:
    """Regression tests for the public entry points exported by ``imgutils.ocr``."""

    @staticmethod
    def _assert_detections(actual, expected, abs_tol=1e-3):
        """
        Compare ``(bbox, text, score)`` triples field by field.

        ``pytest.approx`` only applies its tolerance to numeric values; a tuple
        nested inside a list falls back to strict equality, so wrapping the whole
        result list in ``approx`` silently ignores ``abs``/``rel``. Here the box
        and the text are compared exactly and only the score is compared with a
        tolerance.

        :param actual: Triples returned by the function under test.
        :param expected: Reference triples, in the expected order.
        :param abs_tol: Absolute tolerance applied to the score.
        """
        assert len(actual) == len(expected)
        for (bbox, text, score), (exp_bbox, exp_text, exp_score) in zip(actual, expected):
            assert bbox == exp_bbox
            assert text == exp_text
            assert score == pytest.approx(exp_score, abs=abs_tol)

    def test_detect_text_with_ocr_comic(self, ocr_img_comic):
        """Detection on the comic sample yields the 8 reference boxes, all labelled ``'text'``."""
        detections = detect_text_with_ocr(ocr_img_comic)
        self._assert_detections(detections, [
            ((742, 485, 809, 511), 'text', 0.954),
            ((682, 98, 734, 124), 'text', 0.93),
            ((716, 136, 836, 164), 'text', 0.904),
            ((144, 455, 196, 485), 'text', 0.874),
            ((719, 455, 835, 488), 'text', 0.862),
            ((124, 478, 214, 508), 'text', 0.848),
            ((1030, 557, 1184, 578), 'text', 0.835),
            ((427, 129, 553, 154), 'text', 0.824),
        ])

    def test_detect_text_with_ocr_anime_subtitle(self, ocr_img_anime_subtitle_pil):
        """Detection on the subtitle sample (given as a PIL image) yields the 2 reference boxes."""
        detections = detect_text_with_ocr(ocr_img_anime_subtitle_pil)
        self._assert_detections(detections, [
            ((312, 567, 690, 600), 'text', 0.817),
            ((332, 600, 671, 636), 'text', 0.798),
        ])

    def test_list_det_models(self):
        """A few known detection models are present in the listing."""
        lst = list_det_models()
        assert 'ch_PP-OCRv4_det' in lst
        assert 'ch_ppocr_mobile_v2.0_det' in lst
        assert 'en_PP-OCRv3_det' in lst

    def test_ocr_comic(self, ocr_img_comic):
        """Full OCR on the comic sample yields the reference boxes, texts and scores."""
        detections = ocr(ocr_img_comic)
        self._assert_detections(detections, [
            ((742, 485, 809, 511), 'MOB.', 0.9356705927336156),
            ((716, 136, 836, 164), 'SHISHOU,', 0.8933000384412466),
            ((682, 98, 734, 124), 'BUT', 0.8730931912907247),
            ((144, 455, 196, 485), 'OH,', 0.8417627579351514),
            ((427, 129, 553, 154), 'A MIRROR.', 0.7366019454049503),
            ((1030, 557, 1184, 578), '(EL) GATO IBERICO', 0.7271127306351021),
            ((719, 455, 835, 488), "THAt'S △", 0.701928390168364),
            ((124, 478, 214, 508), 'LOOK!', 0.6965972578194936),
        ], abs_tol=1e-3)

    def test_ocr_plot(self, ocr_img_plot):
        """Full OCR on the plot sample finds at least 75 text regions."""
        detections = ocr(ocr_img_plot)
        assert len(detections) >= 75

    def test_list_rec_models(self):
        """A few known recognition models are present in the listing."""
        lst = list_rec_models()
        assert 'arabic_PP-OCRv3_rec' in lst
        assert 'ch_PP-OCRv4_rec' in lst
        assert 'ch_ppocr_mobile_v2.0_rec' in lst
        assert 'japan_PP-OCRv3_rec' in lst
        assert 'latin_PP-OCRv3_rec' in lst
        assert 'korean_PP-OCRv3_rec' in lst
        assert 'cyrillic_PP-OCRv3_rec' in lst
Binary file added test/testfile/ocr/anime_subtitle.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/testfile/ocr/comic.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/testfile/ocr/plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/testfile/ocr/post_text.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit f79a2bf

Please sign in to comment.