Skip to content

Commit

Permalink
Merge pull request #3 from ajkdrag/feature/add-doctr-fast-models
Browse files Browse the repository at this point in the history
[no ci] Updating reparameterize for Fast models
  • Loading branch information
ajkdrag authored Mar 29, 2024
2 parents 70892ac + a01a87e commit e5fe5ce
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/ocrtoolkit/integrations/doctr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from loguru import logger

from ocrtoolkit.utilities.model_utils import load_state_dict
from ocrtoolkit.utilities.model_utils import load_state_dict, reparameterize
from ocrtoolkit.wrappers.bbox import BBox
from ocrtoolkit.wrappers.detection_results import DetectionResults
from ocrtoolkit.wrappers.model import DetectionModel, RecognitionModel
Expand Down Expand Up @@ -58,7 +58,6 @@ def __init__(self, model, path, device, **kwargs):
from doctr.models.recognition.predictor import RecognitionPredictor

super().__init__(model, path, device)
kwargs.pop("pretrained_backbone", None)
kwargs.pop("vocab", None)
kwargs.pop("max_length", None)

Expand Down Expand Up @@ -111,6 +110,10 @@ def load(
)
if not pretrained:
load_state_dict(path, model)

if model_name.startswith("fast_"):
model = reparameterize(model)

return DoctrDetModel(model, path, device, **kwargs)
elif task == "rec":
from doctr.models import recognition
Expand Down
41 changes: 41 additions & 0 deletions src/ocrtoolkit/utilities/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,44 @@ def load_state_dict(path, model, ignore_keys: list = None):
raise ValueError("Failed to load state_dict. Non-matching keys.")
else:
model.load_state_dict(state_dict)


def reparameterize(model):
import torch

last_conv = None
last_conv_name = None

for module in model.modules():
if hasattr(module, "reparameterize_layer"):
module.reparameterize_layer()

for name, child in model.named_children():
if isinstance(child, torch.nn.BatchNorm2d):
# fuse batchnorm only if it is followed by a conv layer
if last_conv is None:
continue
conv_w = last_conv.weight
conv_b = (
last_conv.bias
if last_conv.bias is not None
else torch.zeros_like(child.running_mean)
)

factor = child.weight / torch.sqrt(child.running_var + child.eps)
last_conv.weight = torch.nn.Parameter(
conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1])
)
last_conv.bias = torch.nn.Parameter(
(conv_b - child.running_mean) * factor + child.bias
)
model._modules[last_conv_name] = last_conv
model._modules[name] = torch.nn.Identity()
last_conv = None
elif isinstance(child, torch.nn.Conv2d):
last_conv = child
last_conv_name = name
else:
reparameterize(child)

return model

0 comments on commit e5fe5ce

Please sign in to comment.