-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from ajkdrag/feature/add-doctr-fast-models
Feature/add doctr fast models
- Loading branch information
Showing
5 changed files
with
236 additions
and
193 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,127 +1,94 @@ | ||
class BaseArch(type):
    """Metaclass that turns "calling" an architecture class into loading it.

    A class using this metaclass is never instantiated: the call is forwarded
    to the class's own ``load`` and whatever ``load`` returns is handed back.

    If path is None, then model is pretrained.
    """

    def __call__(cls, path=None, device="cpu", model_kwargs: dict = None, **kwargs):
        # A fresh dict per call replaces the ``None`` placeholder; a dict
        # supplied by the caller is forwarded untouched (same object).
        resolved_kwargs = {} if model_kwargs is None else model_kwargs
        return cls.load(path, device, resolved_kwargs, **kwargs)
|
||
|
||
class UL_YOLOV8(metaclass=BaseArch):
    """Architecture entry for ``"yolov8"`` in the ultralytics integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.ultralytics.load`` for ``"yolov8"``."""
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.ultralytics as backend

        return backend.load("yolov8", path, device, model_kwargs, **kwargs)
|
||
|
||
class UL_RTDETR(metaclass=BaseArch):
    """Architecture entry for ``"rtdetr"`` in the ultralytics integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.ultralytics.load`` for ``"rtdetr"``."""
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.ultralytics as backend

        return backend.load("rtdetr", path, device, model_kwargs, **kwargs)
|
||
|
||
class DOCTR_CRNN_VGG16(metaclass=BaseArch):
    """Architecture entry for ``"crnn_vgg16_bn"`` (``"rec"``) in the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.doctr.load``.

        ``"rec"`` and ``"crnn_vgg16_bn"`` are passed as the two leading
        arguments, followed by the caller's arguments unchanged.
        """
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.doctr as backend

        return backend.load(
            "rec", "crnn_vgg16_bn", path, device, model_kwargs, **kwargs
        )
|
||
|
||
class DOCTR_CRNN_MOBILENET_L(metaclass=BaseArch):
    """Architecture entry for ``"crnn_mobilenet_v3_large"`` (``"rec"``) in doctr."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.doctr.load``.

        ``"rec"`` and ``"crnn_mobilenet_v3_large"`` are passed as the two
        leading arguments, followed by the caller's arguments unchanged.
        """
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.doctr as backend

        return backend.load(
            "rec", "crnn_mobilenet_v3_large", path, device, model_kwargs, **kwargs
        )
|
||
|
||
class DOCTR_PARSEQ(metaclass=BaseArch):
    """Architecture entry for ``"parseq"`` (``"rec"``) in the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.doctr.load``.

        ``"rec"`` and ``"parseq"`` are passed as the two leading arguments,
        followed by the caller's arguments unchanged.
        """
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.doctr as backend

        return backend.load("rec", "parseq", path, device, model_kwargs, **kwargs)
|
||
|
||
# Module-level on purpose: ``importlib.import_module`` is used by the
# dynamically generated ``load`` function defined further down in this file,
# so the name must not be buried inside a single method body.
import importlib


class DOCTR_DB_RESNET50(metaclass=BaseArch):
    """Architecture entry for ``"db_resnet50"`` (``"det"``) in the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.doctr.load``.

        ``"det"`` and ``"db_resnet50"`` are passed as the two leading
        arguments, followed by the caller's arguments unchanged.
        """
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.doctr as framework

        return framework.load(
            "det", "db_resnet50", path, device, model_kwargs, **kwargs
        )
|
||
|
||
class DOCTR_DB_RESNET34(metaclass=BaseArch):
    """Architecture entry for ``"db_resnet34"`` (``"det"``) in the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.doctr.load``.

        ``"det"`` and ``"db_resnet34"`` are passed as the two leading
        arguments, followed by the caller's arguments unchanged.
        """
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.doctr as backend

        return backend.load(
            "det", "db_resnet34", path, device, model_kwargs, **kwargs
        )
|
||
|
||
class DOCTR_DB_MOBILENET_L(metaclass=BaseArch):
    """Architecture entry for ``"db_mobilenet_v3_large"`` (``"det"``) in doctr."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.doctr.load``.

        ``"det"`` and ``"db_mobilenet_v3_large"`` are passed as the two
        leading arguments, followed by the caller's arguments unchanged.
        """
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.doctr as backend

        return backend.load(
            "det", "db_mobilenet_v3_large", path, device, model_kwargs, **kwargs
        )
|
||
|
||
class DOCTR_FAST_TINY(metaclass=BaseArch):
    """Architecture entry for ``"fast_tiny"`` (``"det"``) in the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.doctr.load``.

        ``"det"`` and ``"fast_tiny"`` are passed as the two leading
        arguments, followed by the caller's arguments unchanged.
        """
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.doctr as backend

        return backend.load("det", "fast_tiny", path, device, model_kwargs, **kwargs)
|
||
|
||
class DOCTR_FAST_SMALL(metaclass=BaseArch):
    """Architecture entry for ``"fast_small"`` (``"det"``) in the doctr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.doctr.load``.

        ``"det"`` and ``"fast_small"`` are passed as the two leading
        arguments, followed by the caller's arguments unchanged.
        """
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.doctr as backend

        return backend.load("det", "fast_small", path, device, model_kwargs, **kwargs)
|
||
|
||
class GCV_OCR(metaclass=BaseArch):
    """Google Cloud Vision OCR

    Here `path` arg points to service account json file
    """

    @staticmethod
    def load(path, _, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.gcv.load``.

        The second positional argument is accepted and ignored; it is not
        passed on to the integration.
        """
        # Function-level import: the integration module is only loaded when
        # this architecture is actually requested.
        import ocrtoolkit.integrations.gcv as backend

        return backend.load(path, model_kwargs, **kwargs)
|
||
# NOTE(review): in the source these lines were interleaved with the
# definitions that follow them; they are untangled here so that every
# definition is contiguous. The two PPOCR classes rely on the ``BaseArch``
# metaclass defined earlier in this file (the one implementing ``__call__``),
# so they must stay above the redefinition of ``BaseArch`` below.
class PPOCR_SVTR_LCNET(metaclass=BaseArch):
    """Architecture entry for ``"SVTR_LCNet"`` (``"rec"``) in the paddleocr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.paddleocr.load``.

        ``"rec"`` and ``"SVTR_LCNet"`` are passed as the two leading
        arguments, followed by the caller's arguments unchanged.
        """
        import ocrtoolkit.integrations.paddleocr as framework

        return framework.load(
            "rec", "SVTR_LCNet", path, device, model_kwargs, **kwargs
        )


class PPOCR_DBNET(metaclass=BaseArch):
    """Architecture entry for ``"DB"`` (``"det"``) in the paddleocr integration."""

    @staticmethod
    def load(path, device, model_kwargs, **kwargs):
        """Forward to ``ocrtoolkit.integrations.paddleocr.load``.

        ``"det"`` and ``"DB"`` are passed as the two leading arguments,
        followed by the caller's arguments unchanged.
        """
        import ocrtoolkit.integrations.paddleocr as framework

        return framework.load("det", "DB", path, device, model_kwargs, **kwargs)


class BaseArch(type):
    """Base class for all architectures.

    Subclasses are expected to provide a ``load`` callable. "Instantiating"
    a subclass does not create an instance: ``__new__`` forwards every
    argument to ``cls.load`` and returns its result.
    """

    def __new__(cls, *args, **kwargs):
        return cls.load(*args, **kwargs)


class ArchitectureFactory:
    """Factory class for creating architecture classes."""

    # Explicitly static: the function takes no ``self`` and is invoked on a
    # factory *instance* elsewhere in this file, which would otherwise bind
    # the instance to ``class_name``.
    @staticmethod
    def create_arch_class(class_name, framework_module, model_name=None, task=None):
        """Create an architecture class dynamically.

        Args:
            class_name: Name of the generated class. Its prefix decides which
                extra keyword arguments the generated ``load`` forwards:
                ``UL_`` adds ``model_name`` and ``device``; ``DOCTR_`` and
                ``PPOCR_`` add ``model_name``, ``task`` and ``device``; any
                other name adds nothing.
            framework_module: Name of the ``ocrtoolkit.integrations``
                submodule whose ``load`` function performs the actual work.
            model_name: Value forwarded as ``model_name`` (see above).
            task: Value forwarded as ``task`` (see above).

        Returns:
            A new class derived from ``BaseArch`` whose static ``load``
            imports the integration module on demand and delegates to it.
        """
        # The prefix never changes after creation, so classify it once here
        # rather than on every ``load`` call.
        wants_model_and_device = class_name.startswith(("UL_", "DOCTR_", "PPOCR_"))
        wants_task = class_name.startswith(("DOCTR_", "PPOCR_"))

        def load(path=None, device="cpu", model_kwargs=None, **kwargs):
            """Load the model with the specified configuration."""
            # Local import keeps this function self-sufficient regardless of
            # what the enclosing module imports at the top level.
            import importlib

            framework = importlib.import_module(
                f"ocrtoolkit.integrations.{framework_module}"
            )

            load_kwargs = {"path": path, "model_kwargs": model_kwargs or {}, **kwargs}

            # Values fixed at class-creation time override same-named entries
            # the caller may have passed through ``**kwargs``.
            if wants_model_and_device:
                load_kwargs["model_name"] = model_name
                load_kwargs["device"] = device
            if wants_task:
                load_kwargs["task"] = task

            return framework.load(**load_kwargs)

        return type(class_name, (BaseArch,), {"load": staticmethod(load)})
|
||
|
||
# One factory instance builds every architecture class exported below.
factory = ArchitectureFactory()

# ultralytics object detection
UL_YOLOV8 = factory.create_arch_class(
    "UL_YOLOV8", "ultralytics", model_name="yolov8"
)
UL_RTDETR = factory.create_arch_class(
    "UL_RTDETR", "ultralytics", model_name="rtdetr"
)

# doctr recognition
DOCTR_CRNN_VGG16 = factory.create_arch_class(
    "DOCTR_CRNN_VGG16", "doctr", model_name="crnn_vgg16_bn", task="rec"
)
DOCTR_CRNN_MOBILENET_L = factory.create_arch_class(
    "DOCTR_CRNN_MOBILENET_L",
    "doctr",
    model_name="crnn_mobilenet_v3_large",
    task="rec",
)
DOCTR_CRNN_MOBILENET_S = factory.create_arch_class(
    "DOCTR_CRNN_MOBILENET_S",
    "doctr",
    model_name="crnn_mobilenet_v3_small",
    task="rec",
)
DOCTR_PARSEQ = factory.create_arch_class(
    "DOCTR_PARSEQ", "doctr", model_name="parseq", task="rec"
)
DOCTR_VITSTR_S = factory.create_arch_class(
    "DOCTR_VITSTR_S", "doctr", model_name="vitstr_small", task="rec"
)
DOCTR_VITSTR_B = factory.create_arch_class(
    "DOCTR_VITSTR_B", "doctr", model_name="vitstr_base", task="rec"
)

# doctr detection
DOCTR_DB_RESNET50 = factory.create_arch_class(
    "DOCTR_DB_RESNET50", "doctr", model_name="db_resnet50", task="det"
)
DOCTR_DB_RESNET34 = factory.create_arch_class(
    "DOCTR_DB_RESNET34", "doctr", model_name="db_resnet34", task="det"
)
DOCTR_DB_MOBILENET_L = factory.create_arch_class(
    "DOCTR_DB_MOBILENET_L", "doctr", model_name="db_mobilenet_v3_large", task="det"
)
DOCTR_FAST_T = factory.create_arch_class(
    "DOCTR_FAST_T", "doctr", model_name="fast_tiny", task="det"
)
DOCTR_FAST_S = factory.create_arch_class(
    "DOCTR_FAST_S", "doctr", model_name="fast_small", task="det"
)
DOCTR_FAST_B = factory.create_arch_class(
    "DOCTR_FAST_B", "doctr", model_name="fast_base", task="det"
)

# paddleocr recognition
PPOCR_SVTR_LCNET = factory.create_arch_class(
    "PPOCR_SVTR_LCNET", "paddleocr", model_name="SVTR_LCNet", task="rec"
)

# paddleocr detection
PPOCR_DBNET = factory.create_arch_class(
    "PPOCR_DBNET", "paddleocr", model_name="DB", task="det"
)

# gcv
GCV_OCR = factory.create_arch_class("GCV_OCR", "gcv")