make engine configurable

felixdittrich92 committed Jun 27, 2024
1 parent 495151c commit 59179b4
Showing 24 changed files with 430 additions and 117 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-ast
- id: check-yaml
@@ -16,7 +16,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.2
rev: v0.4.10
hooks:
- id: ruff
args: [ --fix ]
20 changes: 18 additions & 2 deletions README.md
@@ -72,7 +72,7 @@ Let's use the default `ocr_predictor` model for an example:

```python
from onnxtr.io import DocumentFile
from onnxtr.models import ocr_predictor
from onnxtr.models import ocr_predictor, EngineConfig

model = ocr_predictor(
det_arch='fast_base', # detection architecture
@@ -89,11 +89,15 @@ model = ocr_predictor(
detect_language=False, # set to `True` if the language of the pages should be detected (default: False)
# DocumentBuilder specific parameters
resolve_lines=True, # whether words should be automatically grouped into lines (default: True)
resolve_blocks=True, # whether lines should be automatically grouped into blocks (default: True)
resolve_blocks=False, # whether lines should be automatically grouped into blocks (default: False)
paragraph_break=0.035, # relative length of the minimum space separating paragraphs (default: 0.035)
# OnnxTR specific parameters
# NOTE: 8-bit quantized models are not available for FAST detection models and can generally lead to poorer accuracy
load_in_8_bit=False, # set to `True` to load 8-bit quantized models instead of the full-precision ones (default: False)
# Advanced engine configuration options
det_engine_cfg=EngineConfig(), # detection model engine configuration (default: internal predefined configuration)
reco_engine_cfg=EngineConfig(), # recognition model engine configuration (default: internal predefined configuration)
clf_engine_cfg=EngineConfig(), # classification (orientation) model engine configuration (default: internal predefined configuration)
)
# PDF
doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
@@ -103,6 +107,18 @@ result = model(doc)
result.show()
```

<details>
<summary>Advanced engine configuration options</summary>
A minimal sketch of how the new `EngineConfig` could be wired up, following `onnxruntime`'s `SessionOptions` and execution-provider conventions (the keyword names and concrete values below are assumptions based on this diff, not a definitive API):
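```python
from onnxruntime import SessionOptions

from onnxtr.models import EngineConfig, ocr_predictor

# Tune the underlying onnxruntime session (illustrative settings)
session_options = SessionOptions()
session_options.enable_cpu_mem_arena = False

# Providers follow onnxruntime conventions: names or (name, options) tuples
providers = [("CUDAExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]

# Assumed keyword names, mirroring the parameters documented above
engine_config = EngineConfig(
    session_options=session_options,
    providers=providers,
)

# The same (or a different) configuration can be passed per sub-task
model = ocr_predictor(
    det_engine_cfg=engine_config,
    reco_engine_cfg=engine_config,
    clf_engine_cfg=engine_config,
)
```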
</details>

![Visualization sample](https://github.com/felixdittrich92/OnnxTR/raw/main/docs/images/doctr_example_script.gif)

Or even rebuild the original document from its predictions:
1 change: 1 addition & 0 deletions onnxtr/models/__init__.py
@@ -1,3 +1,4 @@
from .engine import EngineConfig

from .classification import *
from .detection import *
from .recognition import *
2 changes: 1 addition & 1 deletion onnxtr/models/builder.py
@@ -31,7 +31,7 @@ class DocumentBuilder(NestedObject):
def __init__(
self,
resolve_lines: bool = True,
resolve_blocks: bool = True,
resolve_blocks: bool = False,
paragraph_break: float = 0.035,
export_as_straight_boxes: bool = False,
) -> None:
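Since this hunk flips `resolve_blocks` to `False` by default, callers that relied on block grouping now have to opt in explicitly. A small sketch using only the parameters visible above:

```python
from onnxtr.models.builder import DocumentBuilder

# resolve_blocks now defaults to False; pass True to restore the old grouping
builder = DocumentBuilder(
    resolve_lines=True,   # words -> lines (default unchanged)
    resolve_blocks=True,  # lines -> blocks (opt-in after this change)
    paragraph_break=0.035,
)
```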
17 changes: 12 additions & 5 deletions onnxtr/models/classification/models/mobilenet.py
@@ -10,7 +10,7 @@

import numpy as np

from ...engine import Engine
from ...engine import Engine, EngineConfig

__all__ = [
"mobilenet_v3_small_crop_orientation",
@@ -43,17 +43,19 @@ class MobileNetV3(Engine):
Args:
----
model_path: path or url to onnx model file
engine_cfg: configuration for the inference engine
cfg: configuration dictionary
**kwargs: additional arguments to be passed to `Engine`
"""

def __init__(
self,
model_path: str,
engine_cfg: EngineConfig = EngineConfig(),
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
super().__init__(url=model_path, **kwargs)
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.cfg = cfg

def __call__(
@@ -67,17 +69,19 @@ def _mobilenet_v3(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> MobileNetV3:
# Patch the url
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
_cfg = deepcopy(default_cfgs[arch])
return MobileNetV3(model_path, cfg=_cfg, **kwargs)
return MobileNetV3(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)


def mobilenet_v3_small_crop_orientation(
model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
@@ -94,18 +98,20 @@ def mobilenet_v3_small_crop_orientation(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the MobileNetV3 architecture
Returns:
-------
MobileNetV3
"""
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, **kwargs)
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)


def mobilenet_v3_small_page_orientation(
model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
@@ -122,10 +128,11 @@ def mobilenet_v3_small_page_orientation(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the MobileNetV3 architecture
Returns:
-------
MobileNetV3
"""
return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, **kwargs)
return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
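A hedged usage sketch of the new parameter, assuming the factories are re-exported by `onnxtr.models.classification` (as the zoo module below suggests):

```python
from onnxtr.models import EngineConfig
from onnxtr.models.classification import mobilenet_v3_small_crop_orientation

# The default EngineConfig() mirrors the internal predefined configuration;
# a custom instance is forwarded through _mobilenet_v3 to the Engine base class.
model = mobilenet_v3_small_crop_orientation(
    load_in_8_bit=False,
    engine_cfg=EngineConfig(),
)
```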
24 changes: 18 additions & 6 deletions onnxtr/models/classification/zoo.py
@@ -5,6 +5,8 @@

from typing import Any, List

from onnxtr.models.engine import EngineConfig

from .. import classification
from ..preprocessor import PreProcessor
from .predictor import OrientationPredictor
@@ -14,12 +16,14 @@
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any) -> OrientationPredictor:
def _orientation_predictor(
arch: str, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any
) -> OrientationPredictor:
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

# Load directly classifier from backbone
_model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit)
_model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg)
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
@@ -32,7 +36,10 @@ def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any


def crop_orientation_predictor(
arch: Any = "mobilenet_v3_small_crop_orientation", load_in_8_bit: bool = False, **kwargs: Any
arch: Any = "mobilenet_v3_small_crop_orientation",
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> OrientationPredictor:
"""Crop orientation classification architecture.
@@ -46,17 +53,21 @@ def crop_orientation_predictor(
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
load_in_8_bit: load the 8-bit quantized version of the model
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments to be passed to the OrientationPredictor
Returns:
-------
OrientationPredictor
"""
return _orientation_predictor(arch, load_in_8_bit, **kwargs)
return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs)


def page_orientation_predictor(
arch: Any = "mobilenet_v3_small_page_orientation", load_in_8_bit: bool = False, **kwargs: Any
arch: Any = "mobilenet_v3_small_page_orientation",
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> OrientationPredictor:
"""Page orientation classification architecture.
@@ -70,10 +81,11 @@
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments to be passed to the OrientationPredictor
Returns:
-------
OrientationPredictor
"""
return _orientation_predictor(arch, load_in_8_bit, **kwargs)
return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
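Both zoo helpers now thread `engine_cfg` through to the backbone. A sketch of pinning the orientation models to CPU execution (the `providers` keyword is an assumption; provider names follow `onnxruntime`):

```python
from onnxtr.models import EngineConfig
from onnxtr.models.classification.zoo import (
    crop_orientation_predictor,
    page_orientation_predictor,
)

# Restrict both orientation predictors to the CPU execution provider
cpu_cfg = EngineConfig(providers=["CPUExecutionProvider"])

crop_clf = crop_orientation_predictor(engine_cfg=cpu_cfg)
page_clf = page_orientation_predictor(engine_cfg=cpu_cfg)
```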
39 changes: 27 additions & 12 deletions onnxtr/models/detection/models/differentiable_binarization.py
@@ -8,7 +8,7 @@
import numpy as np
from scipy.special import expit

from ...engine import Engine
from ...engine import Engine, EngineConfig
from ..postprocessor.base import GeneralDetectionPostProcessor

__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
@@ -33,8 +33,8 @@
"input_shape": (3, 1024, 1024),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_mobilenet_v3_large-1866973f.onnx",
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_mobilenet_v3_large_static_8_bit-51659bb9.onnx",
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large-4987e7bd.onnx",
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large_static_8_bit-535a6f25.onnx",
},
}

@@ -45,6 +45,7 @@ class DBNet(Engine):
Args:
----
model_path: path or url to onnx model file
engine_cfg: configuration for the inference engine
bin_thresh: threshold for binarization of the output feature map
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
@@ -54,14 +55,15 @@

def __init__(
self,
model_path,
model_path: str,
engine_cfg: EngineConfig = EngineConfig(),
bin_thresh: float = 0.3,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
super().__init__(url=model_path, **kwargs)
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.cfg = cfg
self.assume_straight_pages = assume_straight_pages
self.postprocessor = GeneralDetectionPostProcessor(
@@ -91,16 +93,20 @@ def _dbnet(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> DBNet:
# Patch the url
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
# Build the model
return DBNet(model_path, cfg=default_cfgs[arch], **kwargs)
return DBNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)


def db_resnet34(
model_path: str = default_cfgs["db_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any
model_path: str = default_cfgs["db_resnet34"]["url"],
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-34 backbone.
@@ -115,17 +121,21 @@ def db_resnet34(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
return _dbnet("db_resnet34", model_path, load_in_8_bit, **kwargs)
return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)


def db_resnet50(
model_path: str = default_cfgs["db_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any
model_path: str = default_cfgs["db_resnet50"]["url"],
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
@@ -140,17 +150,21 @@ def db_resnet50(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
return _dbnet("db_resnet50", model_path, load_in_8_bit, **kwargs)
return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)


def db_mobilenet_v3_large(
model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, **kwargs: Any
model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
@@ -165,10 +179,11 @@ def db_mobilenet_v3_large(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs)
return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
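The detection factories follow the same pattern. A minimal sketch, assuming `db_resnet50` is exported from `onnxtr.models.detection` and that `EngineConfig` accepts a `providers` list (an onnxruntime convention, not shown in this diff):

```python
from onnxtr.models import EngineConfig
from onnxtr.models.detection import db_resnet50

# Prefer GPU 0 for detection, falling back to CPU
det_cfg = EngineConfig(
    providers=[("CUDAExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]
)
det_model = db_resnet50(load_in_8_bit=False, engine_cfg=det_cfg)
```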