
[Engine] make engine configurable #19

Merged · 4 commits · Jun 27, 2024
2 changes: 1 addition & 1 deletion .conda/meta.yaml
@@ -1,7 +1,7 @@
{% set pyproject = load_file_data('../pyproject.toml', from_recipe_dir=True) %}
{% set project = pyproject.get('project') %}
{% set urls = pyproject.get('project', {}).get('urls') %}
-{% set version = environ.get('BUILD_VERSION', '0.2.1a0') %}
+{% set version = environ.get('BUILD_VERSION', '0.3.0a0') %}

package:
name: onnxtr
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
-rev: v4.5.0
+rev: v4.6.0
hooks:
- id: check-ast
- id: check-yaml
@@ -16,7 +16,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.3.2
+rev: v0.4.10
hooks:
- id: ruff
args: [ --fix ]
41 changes: 39 additions & 2 deletions README.md
@@ -72,7 +72,7 @@

```python
from onnxtr.io import DocumentFile
-from onnxtr.models import ocr_predictor
+from onnxtr.models import ocr_predictor, EngineConfig

model = ocr_predictor(
det_arch='fast_base', # detection architecture
@@ -89,11 +89,15 @@
detect_language=False, # set to `True` if the language of the pages should be detected (default: False)
# DocumentBuilder specific parameters
resolve_lines=True, # whether words should be automatically grouped into lines (default: True)
-resolve_blocks=True, # whether lines should be automatically grouped into blocks (default: True)
+resolve_blocks=False, # whether lines should be automatically grouped into blocks (default: False)
paragraph_break=0.035, # relative length of the minimum space separating paragraphs (default: 0.035)
# OnnxTR specific parameters
# NOTE: 8-bit quantized models are not available for FAST detection models and can generally lead to poorer accuracy
load_in_8_bit=False, # set to `True` to load 8-bit quantized models instead of the full precision ones (default: False)
# Advanced engine configuration options
det_engine_cfg=EngineConfig(), # detection model engine configuration (default: internal predefined configuration)
reco_engine_cfg=EngineConfig(), # recognition model engine configuration (default: internal predefined configuration)
clf_engine_cfg=EngineConfig(), # classification (orientation) model engine configuration (default: internal predefined configuration)
)
# PDF
doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
@@ -103,6 +107,39 @@
result.show()
```

<details>
<summary>Advanced engine configuration options</summary>

You can also define advanced engine configurations for the models / predictors:

```python
from onnxruntime import SessionOptions

from onnxtr.models import ocr_predictor, EngineConfig

general_options = SessionOptions() # For configuration options see: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
general_options.enable_cpu_mem_arena = False

# NOTE: The following forces execution on the GPU only; if no GPU is available, an error will be raised
# List of strings e.g. ["CUDAExecutionProvider", "CPUExecutionProvider"] or a list of tuples with the provider and its options e.g.
# [("CUDAExecutionProvider", {"device_id": 0}), ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
providers = [("CUDAExecutionProvider", {"device_id": 0})] # For available providers see: https://onnxruntime.ai/docs/execution-providers/

engine_config = EngineConfig(
session_options=general_options,
providers=providers
)
# We use the default predictor with the custom engine configuration
# NOTE: You can define different engine configurations for detection, recognition and classification depending on your needs
predictor = ocr_predictor(
det_engine_cfg=engine_config,
reco_engine_cfg=engine_config,
clf_engine_cfg=engine_config
)
```

</details>

![Visualization sample](https://github.com/felixdittrich92/OnnxTR/raw/main/docs/images/doctr_example_script.gif)

Or even rebuild the original document from its predictions:
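As a sketch of the CPU-fallback pattern the comments in the README diff hint at: `EngineConfig` and `ocr_predictor` come from this PR, while `intra_op_num_threads` and `graph_optimization_level` are standard ONNX Runtime `SessionOptions` attributes, not additions of this PR.

```python
from onnxruntime import GraphOptimizationLevel, SessionOptions

from onnxtr.models import EngineConfig, ocr_predictor

# Standard onnxruntime session tuning (not introduced by this PR)
opts = SessionOptions()
opts.intra_op_num_threads = 4
opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

# CUDA first, CPU as fallback: unlike the GPU-only setup above, this
# does not raise on machines without a GPU
providers = [
    ("CUDAExecutionProvider", {"device_id": 0}),
    ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"}),
]

engine_config = EngineConfig(session_options=opts, providers=providers)
predictor = ocr_predictor(
    det_engine_cfg=engine_config,
    reco_engine_cfg=engine_config,
    clf_engine_cfg=engine_config,
)
```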
1 change: 1 addition & 0 deletions onnxtr/models/__init__.py
@@ -1,3 +1,4 @@
from .engine import EngineConfig

from .classification import *
from .detection import *
from .recognition import *
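In effect, `EngineConfig` becomes importable straight from `onnxtr.models`, exactly as the README example uses it:

```python
from onnxtr.models import EngineConfig, ocr_predictor

# EngineConfig() falls back to the internal predefined configuration (per the README above)
predictor = ocr_predictor(det_engine_cfg=EngineConfig())
```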
2 changes: 1 addition & 1 deletion onnxtr/models/builder.py
@@ -31,7 +31,7 @@ class DocumentBuilder(NestedObject):
def __init__(
self,
resolve_lines: bool = True,
-resolve_blocks: bool = True,
+resolve_blocks: bool = False,
paragraph_break: float = 0.035,
export_as_straight_boxes: bool = False,
) -> None:
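Since this flips the `resolve_blocks` default from `True` to `False`, downstream code that relied on block grouping has to opt back in. A minimal sketch via the predictor parameter shown in the README diff:

```python
from onnxtr.models import ocr_predictor

# resolve_blocks now defaults to False; pass True explicitly to keep
# grouping lines into blocks as before this change
predictor = ocr_predictor(resolve_blocks=True)
```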
17 changes: 12 additions & 5 deletions onnxtr/models/classification/models/mobilenet.py
@@ -10,7 +10,7 @@

import numpy as np

-from ...engine import Engine
+from ...engine import Engine, EngineConfig

__all__ = [
"mobilenet_v3_small_crop_orientation",
@@ -43,17 +43,19 @@ class MobileNetV3(Engine):
Args:
----
model_path: path or url to onnx model file
engine_cfg: configuration for the inference engine
cfg: configuration dictionary
**kwargs: additional arguments to be passed to `Engine`
"""

def __init__(
self,
model_path: str,
engine_cfg: EngineConfig = EngineConfig(),
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
-super().__init__(url=model_path, **kwargs)
+super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.cfg = cfg

def __call__(
@@ -67,17 +69,19 @@ def _mobilenet_v3(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> MobileNetV3:
# Patch the url
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
_cfg = deepcopy(default_cfgs[arch])
-return MobileNetV3(model_path, cfg=_cfg, **kwargs)
+return MobileNetV3(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)


def mobilenet_v3_small_crop_orientation(
model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
@@ -94,18 +98,20 @@ def mobilenet_v3_small_crop_orientation(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the MobileNetV3 architecture

Returns:
-------
MobileNetV3
"""
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, **kwargs)
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)


def mobilenet_v3_small_page_orientation(
model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
@@ -122,10 +128,11 @@ def mobilenet_v3_small_page_orientation(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the MobileNetV3 architecture

Returns:
-------
MobileNetV3
"""
return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, **kwargs)
return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
24 changes: 18 additions & 6 deletions onnxtr/models/classification/zoo.py
@@ -5,6 +5,8 @@

from typing import Any, List

from onnxtr.models.engine import EngineConfig

from .. import classification
from ..preprocessor import PreProcessor
from .predictor import OrientationPredictor
@@ -14,12 +16,14 @@
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


-def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any) -> OrientationPredictor:
+def _orientation_predictor(
+arch: str, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any
+) -> OrientationPredictor:
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

# Load directly classifier from backbone
-_model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit)
+_model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg)
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
@@ -32,7 +36,10 @@ def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any


def crop_orientation_predictor(
arch: Any = "mobilenet_v3_small_crop_orientation", load_in_8_bit: bool = False, **kwargs: Any
arch: Any = "mobilenet_v3_small_crop_orientation",
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> OrientationPredictor:
"""Crop orientation classification architecture.

@@ -46,17 +53,21 @@ def crop_orientation_predictor(
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
load_in_8_bit: load the 8-bit quantized version of the model
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments to be passed to the OrientationPredictor

Returns:
-------
OrientationPredictor
"""
-return _orientation_predictor(arch, load_in_8_bit, **kwargs)
+return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs)


def page_orientation_predictor(
arch: Any = "mobilenet_v3_small_page_orientation", load_in_8_bit: bool = False, **kwargs: Any
arch: Any = "mobilenet_v3_small_page_orientation",
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> OrientationPredictor:
"""Page orientation classification architecture.

@@ -70,10 +81,11 @@
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments to be passed to the OrientationPredictor

Returns:
-------
OrientationPredictor
"""
-return _orientation_predictor(arch, load_in_8_bit, **kwargs)
+return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
39 changes: 27 additions & 12 deletions onnxtr/models/detection/models/differentiable_binarization.py
@@ -8,7 +8,7 @@
import numpy as np
from scipy.special import expit

-from ...engine import Engine
+from ...engine import Engine, EngineConfig
from ..postprocessor.base import GeneralDetectionPostProcessor

__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
@@ -33,8 +33,8 @@
"input_shape": (3, 1024, 1024),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_mobilenet_v3_large-1866973f.onnx",
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_mobilenet_v3_large_static_8_bit-51659bb9.onnx",
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large-4987e7bd.onnx",
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large_static_8_bit-535a6f25.onnx",
},
}

@@ -45,6 +45,7 @@ class DBNet(Engine):
Args:
----
model_path: path or url to onnx model file
engine_cfg: configuration for the inference engine
bin_thresh: threshold for binarization of the output feature map
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
Expand All @@ -54,14 +55,15 @@ class DBNet(Engine):

def __init__(
self,
-model_path,
+model_path: str,
+engine_cfg: EngineConfig = EngineConfig(),
bin_thresh: float = 0.3,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
-super().__init__(url=model_path, **kwargs)
+super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.cfg = cfg
self.assume_straight_pages = assume_straight_pages
self.postprocessor = GeneralDetectionPostProcessor(
@@ -91,16 +93,20 @@ def _dbnet(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> DBNet:
# Patch the url
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
# Build the model
-return DBNet(model_path, cfg=default_cfgs[arch], **kwargs)
+return DBNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)


def db_resnet34(
model_path: str = default_cfgs["db_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any
model_path: str = default_cfgs["db_resnet34"]["url"],
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-34 backbone.
@@ -115,17 +121,21 @@ def db_resnet34(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture

Returns:
-------
text detection architecture
"""
return _dbnet("db_resnet34", model_path, load_in_8_bit, **kwargs)
return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)


def db_resnet50(
model_path: str = default_cfgs["db_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any
model_path: str = default_cfgs["db_resnet50"]["url"],
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
@@ -140,17 +150,21 @@ def db_resnet50(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture

Returns:
-------
text detection architecture
"""
return _dbnet("db_resnet50", model_path, load_in_8_bit, **kwargs)
return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)


def db_mobilenet_v3_large(
model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, **kwargs: Any
model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
load_in_8_bit: bool = False,
engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
@@ -165,10 +179,11 @@
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture

Returns:
-------
text detection architecture
"""
return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs)
return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)