Skip to content

Commit

Permalink
Merge pull request #4 from ajkdrag/feature/add-doctr-fast-models
Browse files Browse the repository at this point in the history
Feature/add doctr fast models
  • Loading branch information
ajkdrag authored Mar 30, 2024
2 parents e5fe5ce + 4abfc7e commit ba01225
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 193 deletions.
10 changes: 5 additions & 5 deletions extra-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#
# package[version_required]: tag1, tag2, ...

ultralytics==8.1.11: ultralytics
dill==0.3.8: ultralytics
paddleocr==2.7.0.3: paddle
paddlepaddle-gpu==2.6.0: paddle
python-doctr[torch]==0.8.1: doctr
ultralytics==8.1.11: ultralytics
dill==0.3.8: ultralytics
paddleocr==2.7.0.3: paddle
paddlepaddle-gpu==2.6.0: paddle
python-doctr[torch] @ git+https://github.com/mindee/doctr.git@8c85c3654e4ae0a045a990d6f23973bc26d3483c: doctr
129 changes: 99 additions & 30 deletions notebooks/combining_different_frameworks.ipynb

Large diffs are not rendered by default.

76 changes: 42 additions & 34 deletions notebooks/working_with_models.ipynb

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@


def get_extra_requires(path, add_all=True):
import re
from collections import defaultdict

with open(path) as fp:
Expand All @@ -14,9 +13,9 @@ def get_extra_requires(path, add_all=True):
if k.strip() and not k.startswith("#"):
tags = set()
if ":" in k:
k, v = k.split(":")
k, v = k.rsplit(":", 1)
tags.update(vv.strip() for vv in v.split(","))
tags.add(re.split("[<=>]", k)[0])
# tags.add(re.split("[<=>]", k)[0])
for t in tags:
extra_deps[t].add(k)

Expand Down Expand Up @@ -61,7 +60,7 @@ def parse_requirements(filename):


requirements = parse_requirements("requirements.txt")

print(get_extra_requires("extra-requirements.txt"))

if __name__ == "__main__":
setup(
Expand Down
207 changes: 87 additions & 120 deletions src/ocrtoolkit/models/arch.py
Original file line number Diff line number Diff line change
@@ -1,127 +1,94 @@
class BaseArch(type):
    """Metaclass shared by all architecture classes.

    Calling a class built with this metaclass does not create an instance:
    the call is handed to the class's ``load`` and its result is returned.
    If ``path`` is None, then the model is pretrained.
    """

    def __call__(cls, path=None, device="cpu", model_kwargs: dict = None, **kwargs):
        # Substitute a fresh dict only when the caller passed nothing at all;
        # a dict supplied by the caller is forwarded as the same object.
        effective_kwargs = {} if model_kwargs is None else model_kwargs
        return cls.load(path, device, effective_kwargs, **kwargs)


class UL_YOLOV8(metaclass=BaseArch):
    """Architecture whose ``load`` asks the ultralytics integration for "yolov8"."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward the load request to ``ocrtoolkit.integrations.ultralytics``."""
        # The integration module is imported at call time.
        import ocrtoolkit.integrations.ultralytics as ul_backend

        arch_name = "yolov8"
        return ul_backend.load(arch_name, path, device, model_kwargs, **kwargs)


class UL_RTDETR(metaclass=BaseArch):
    """Architecture whose ``load`` asks the ultralytics integration for "rtdetr"."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward the load request to ``ocrtoolkit.integrations.ultralytics``."""
        # The integration module is imported at call time.
        import ocrtoolkit.integrations.ultralytics as ul_backend

        arch_name = "rtdetr"
        return ul_backend.load(arch_name, path, device, model_kwargs, **kwargs)


class DOCTR_CRNN_VGG16(metaclass=BaseArch):
    """Architecture whose ``load`` asks the doctr integration for ("rec", "crnn_vgg16_bn")."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward the load request to ``ocrtoolkit.integrations.doctr``."""
        # The integration module is imported at call time.
        import ocrtoolkit.integrations.doctr as doctr_backend

        selector = ("rec", "crnn_vgg16_bn")
        return doctr_backend.load(*selector, path, device, model_kwargs, **kwargs)


class DOCTR_CRNN_MOBILENET_L(metaclass=BaseArch):
    """Architecture whose ``load`` asks the doctr integration for
    ("rec", "crnn_mobilenet_v3_large").
    """

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward the load request to ``ocrtoolkit.integrations.doctr``."""
        # The integration module is imported at call time.
        import ocrtoolkit.integrations.doctr as doctr_backend

        selector = ("rec", "crnn_mobilenet_v3_large")
        return doctr_backend.load(*selector, path, device, model_kwargs, **kwargs)


class DOCTR_PARSEQ(metaclass=BaseArch):
    """Architecture whose ``load`` asks the doctr integration for ("rec", "parseq")."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward the load request to ``ocrtoolkit.integrations.doctr``."""
        # The integration module is imported at call time.
        import ocrtoolkit.integrations.doctr as doctr_backend

        selector = ("rec", "parseq")
        return doctr_backend.load(*selector, path, device, model_kwargs, **kwargs)


# Architecture class whose load() forwards to the doctr integration's load
# with ("det", "db_resnet50") placed ahead of the caller's arguments.
class DOCTR_DB_RESNET50(metaclass=BaseArch):
@staticmethod
def load(path, device, model_kwargs, **kwargs):
# The integration module is imported at call time.
import ocrtoolkit.integrations.doctr as framework
# NOTE(review): importlib is not referenced anywhere in this method;
# confirm whether this import is meant to live here before removing it.
import importlib

return framework.load(
"det", "db_resnet50", path, device, model_kwargs, **kwargs
)


class DOCTR_DB_RESNET34(metaclass=BaseArch):
    """Architecture whose ``load`` asks the doctr integration for ("det", "db_resnet34")."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward the load request to ``ocrtoolkit.integrations.doctr``."""
        # The integration module is imported at call time.
        import ocrtoolkit.integrations.doctr as doctr_backend

        selector = ("det", "db_resnet34")
        return doctr_backend.load(*selector, path, device, model_kwargs, **kwargs)


class DOCTR_DB_MOBILENET_L(metaclass=BaseArch):
    """Architecture whose ``load`` asks the doctr integration for
    ("det", "db_mobilenet_v3_large").
    """

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward the load request to ``ocrtoolkit.integrations.doctr``."""
        # The integration module is imported at call time.
        import ocrtoolkit.integrations.doctr as doctr_backend

        selector = ("det", "db_mobilenet_v3_large")
        return doctr_backend.load(*selector, path, device, model_kwargs, **kwargs)


class DOCTR_FAST_TINY(metaclass=BaseArch):
    """Architecture whose ``load`` asks the doctr integration for ("det", "fast_tiny")."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward the load request to ``ocrtoolkit.integrations.doctr``."""
        # The integration module is imported at call time.
        import ocrtoolkit.integrations.doctr as doctr_backend

        selector = ("det", "fast_tiny")
        return doctr_backend.load(*selector, path, device, model_kwargs, **kwargs)


class DOCTR_FAST_SMALL(metaclass=BaseArch):
    """Architecture whose ``load`` asks the doctr integration for ("det", "fast_small")."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward the load request to ``ocrtoolkit.integrations.doctr``."""
        # The integration module is imported at call time.
        import ocrtoolkit.integrations.doctr as doctr_backend

        selector = ("det", "fast_small")
        return doctr_backend.load(*selector, path, device, model_kwargs, **kwargs)


class GCV_OCR(metaclass=BaseArch):
    """Google Cloud Vision OCR.

    For this architecture the ``path`` argument points to the service
    account json file.
    """

    @staticmethod
    def load(path, _, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.gcv``.

        The second positional argument is accepted but not passed on.
        """
        # The integration module is imported at call time.
        import ocrtoolkit.integrations.gcv as gcv_backend

        loader = gcv_backend.load
        return loader(path, model_kwargs, **kwargs)

# NOTE(review): BaseArch's `__new__` appears only after the PPOCR_SVTR_LCNET
# class statement below, and the bodies of the definitions in this region are
# mixed together; confirm the intended layout before editing any of it.
class BaseArch(type):
"""Base class for all architectures."""

class PPOCR_SVTR_LCNET(metaclass=BaseArch):
@staticmethod
def load(path, device, model_kwargs, **kwargs):
import ocrtoolkit.integrations.paddleocr as framework
# Passes every positional and keyword argument through to cls.load and
# returns whatever that call returns.
def __new__(cls, *args, **kwargs):
return cls.load(*args, **kwargs)

# Calls the paddleocr integration's load with ("rec", "SVTR_LCNet") placed
# ahead of the caller's arguments.
return framework.load("rec", "SVTR_LCNet", path, device, model_kwargs, **kwargs)

class ArchitectureFactory:
"""Factory class for creating architecture classes."""

class PPOCR_DBNET(metaclass=BaseArch):
@staticmethod
def load(path, device, model_kwargs, **kwargs):
import ocrtoolkit.integrations.paddleocr as framework

# Calls the paddleocr integration's load with ("det", "DB") placed ahead of
# the caller's arguments.
return framework.load("det", "DB", path, device, model_kwargs, **kwargs)
def create_arch_class(class_name, framework_module, model_name=None, task=None):
    """Build and return a new architecture class named ``class_name``.

    The class is created with ``type`` using ``BaseArch`` as its only base
    and a single static ``load``. That ``load`` imports
    ``ocrtoolkit.integrations.<framework_module>`` and calls its ``load``
    with keyword arguments only. ``path``, ``model_kwargs`` and any extra
    keyword arguments are always sent; the rest depends on the prefix of
    ``class_name``:

    * ``UL_``: ``model_name`` and ``device`` are added.
    * ``DOCTR_`` / ``PPOCR_``: ``model_name``, ``task`` and ``device`` are added.
    * any other prefix: nothing is added.
    """

    def load(path=None, device="cpu", model_kwargs=None, **kwargs):
        """Import the integration module and call its ``load``."""
        integration = importlib.import_module(
            f"ocrtoolkit.integrations.{framework_module}"
        )

        # A falsy model_kwargs (None or an empty dict) is replaced by a new dict.
        call_args = {"path": path, "model_kwargs": model_kwargs or {}}
        call_args.update(kwargs)

        wants_task = class_name.startswith(("DOCTR_", "PPOCR_"))
        if wants_task or class_name.startswith("UL_"):
            # Assigned after the update above, so these values win over
            # same-named entries in kwargs.
            call_args["model_name"] = model_name
            if wants_task:
                call_args["task"] = task
            call_args["device"] = device

        return integration.load(**call_args)

    namespace = {"load": staticmethod(load)}
    return type(class_name, (BaseArch,), namespace)


# Factory instance used to build every architecture class defined below.
factory = ArchitectureFactory()

# ---- ultralytics: object detection ----
UL_YOLOV8 = factory.create_arch_class(
    "UL_YOLOV8",
    "ultralytics",
    "yolov8",
)
UL_RTDETR = factory.create_arch_class(
    "UL_RTDETR",
    "ultralytics",
    "rtdetr",
)

# ---- doctr: recognition ----
DOCTR_CRNN_VGG16 = factory.create_arch_class(
    "DOCTR_CRNN_VGG16",
    "doctr",
    "crnn_vgg16_bn",
    "rec",
)
DOCTR_CRNN_MOBILENET_L = factory.create_arch_class(
    "DOCTR_CRNN_MOBILENET_L",
    "doctr",
    "crnn_mobilenet_v3_large",
    "rec",
)
DOCTR_CRNN_MOBILENET_S = factory.create_arch_class(
    "DOCTR_CRNN_MOBILENET_S",
    "doctr",
    "crnn_mobilenet_v3_small",
    "rec",
)
DOCTR_PARSEQ = factory.create_arch_class(
    "DOCTR_PARSEQ",
    "doctr",
    "parseq",
    "rec",
)
DOCTR_VITSTR_S = factory.create_arch_class(
    "DOCTR_VITSTR_S",
    "doctr",
    "vitstr_small",
    "rec",
)
DOCTR_VITSTR_B = factory.create_arch_class(
    "DOCTR_VITSTR_B",
    "doctr",
    "vitstr_base",
    "rec",
)

# ---- doctr: detection ----
DOCTR_DB_RESNET50 = factory.create_arch_class(
    "DOCTR_DB_RESNET50",
    "doctr",
    "db_resnet50",
    "det",
)
DOCTR_DB_RESNET34 = factory.create_arch_class(
    "DOCTR_DB_RESNET34",
    "doctr",
    "db_resnet34",
    "det",
)
DOCTR_DB_MOBILENET_L = factory.create_arch_class(
    "DOCTR_DB_MOBILENET_L",
    "doctr",
    "db_mobilenet_v3_large",
    "det",
)
DOCTR_FAST_T = factory.create_arch_class(
    "DOCTR_FAST_T",
    "doctr",
    "fast_tiny",
    "det",
)
DOCTR_FAST_S = factory.create_arch_class(
    "DOCTR_FAST_S",
    "doctr",
    "fast_small",
    "det",
)
DOCTR_FAST_B = factory.create_arch_class(
    "DOCTR_FAST_B",
    "doctr",
    "fast_base",
    "det",
)

# ---- paddleocr: recognition ----
PPOCR_SVTR_LCNET = factory.create_arch_class(
    "PPOCR_SVTR_LCNET",
    "paddleocr",
    "SVTR_LCNet",
    "rec",
)

# ---- paddleocr: detection ----
PPOCR_DBNET = factory.create_arch_class(
    "PPOCR_DBNET",
    "paddleocr",
    "DB",
    "det",
)

# ---- gcv: created without a model name or task ----
GCV_OCR = factory.create_arch_class(
    "GCV_OCR",
    "gcv",
)

0 comments on commit ba01225

Please sign in to comment.