diff --git a/.conda/meta.yaml b/.conda/meta.yaml
index 38aba5e..db71586 100644
--- a/.conda/meta.yaml
+++ b/.conda/meta.yaml
@@ -1,7 +1,7 @@
{% set pyproject = load_file_data('../pyproject.toml', from_recipe_dir=True) %}
{% set project = pyproject.get('project') %}
{% set urls = pyproject.get('project', {}).get('urls') %}
-{% set version = environ.get('BUILD_VERSION', '0.2.1a0') %}
+{% set version = environ.get('BUILD_VERSION', '0.3.0a0') %}
package:
name: onnxtr
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8e66bb4..a105406 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.5.0
+ rev: v4.6.0
hooks:
- id: check-ast
- id: check-yaml
@@ -16,7 +16,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.3.2
+ rev: v0.4.10
hooks:
- id: ruff
args: [ --fix ]
diff --git a/README.md b/README.md
index 458a136..122eac2 100644
--- a/README.md
+++ b/README.md
@@ -72,7 +72,7 @@ Let's use the default `ocr_predictor` model for an example:
```python
from onnxtr.io import DocumentFile
-from onnxtr.models import ocr_predictor
+from onnxtr.models import ocr_predictor, EngineConfig
model = ocr_predictor(
det_arch='fast_base', # detection architecture
@@ -89,11 +89,15 @@ model = ocr_predictor(
detect_language=False, # set to `True` if the language of the pages should be detected (default: False)
# DocumentBuilder specific parameters
resolve_lines=True, # whether words should be automatically grouped into lines (default: True)
- resolve_blocks=True, # whether lines should be automatically grouped into blocks (default: True)
+ resolve_blocks=False, # whether lines should be automatically grouped into blocks (default: False)
paragraph_break=0.035, # relative length of the minimum space separating paragraphs (default: 0.035)
# OnnxTR specific parameters
# NOTE: 8-Bit quantized models are not available for FAST detection models and can in general lead to poorer accuracy
load_in_8_bit=False, # set to `True` to load 8-bit quantized models instead of the full precision onces (default: False)
+ # Advanced engine configuration options
+ det_engine_cfg=EngineConfig(), # detection model engine configuration (default: internal predefined configuration)
+ reco_engine_cfg=EngineConfig(), # recognition model engine configuration (default: internal predefined configuration)
+ clf_engine_cfg=EngineConfig(), # classification (orientation) model engine configuration (default: internal predefined configuration)
)
# PDF
doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
@@ -103,6 +107,39 @@ result = model(doc)
result.show()
```
+
+**Advanced engine configuration options**
+
+You can also define advanced engine configurations for the models / predictors:
+
+```python
+from onnxruntime import SessionOptions
+
+from onnxtr.models import ocr_predictor, EngineConfig
+
+general_options = SessionOptions() # For configuration options see: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
+general_options.enable_cpu_mem_arena = False
+
+# NOTE: The following forces execution on the GPU; if no GPU is available, an error will be raised
+# Providers can be given as a list of strings, e.g. ["CUDAExecutionProvider", "CPUExecutionProvider"],
+# or as a list of (provider, options) tuples, e.g.
+# [("CUDAExecutionProvider", {"device_id": 0}), ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
+providers = [("CUDAExecutionProvider", {"device_id": 0})] # For available providers see: https://onnxruntime.ai/docs/execution-providers/
+
+engine_config = EngineConfig(
+ session_options=general_options,
+ providers=providers
+)
+# We use the default predictor with the custom engine configuration
+# NOTE: You can define different engine configurations for detection, recognition and classification depending on your needs
+predictor = ocr_predictor(
+ det_engine_cfg=engine_config,
+ reco_engine_cfg=engine_config,
+ clf_engine_cfg=engine_config
+)
+```
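+
+If you are unsure which execution providers your ONNX Runtime build supports, you can query them directly (a minimal snippet using only the `onnxruntime` API, independent of OnnxTR):
+
+```python
+import onnxruntime as ort
+
+print(ort.get_available_providers())  # e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']
+print(ort.get_device())  # 'CPU' or 'GPU' depending on your build
+```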
+
+
+

Or even rebuild the original document from its predictions:
diff --git a/onnxtr/models/__init__.py b/onnxtr/models/__init__.py
index 4e4f327..e612989 100644
--- a/onnxtr/models/__init__.py
+++ b/onnxtr/models/__init__.py
@@ -1,3 +1,4 @@
+from .engine import EngineConfig
from .classification import *
from .detection import *
from .recognition import *
diff --git a/onnxtr/models/builder.py b/onnxtr/models/builder.py
index 77eed6c..ee7f182 100644
--- a/onnxtr/models/builder.py
+++ b/onnxtr/models/builder.py
@@ -31,7 +31,7 @@ class DocumentBuilder(NestedObject):
def __init__(
self,
resolve_lines: bool = True,
- resolve_blocks: bool = True,
+ resolve_blocks: bool = False,
paragraph_break: float = 0.035,
export_as_straight_boxes: bool = False,
) -> None:
diff --git a/onnxtr/models/classification/models/mobilenet.py b/onnxtr/models/classification/models/mobilenet.py
index bdfbeda..583f146 100644
--- a/onnxtr/models/classification/models/mobilenet.py
+++ b/onnxtr/models/classification/models/mobilenet.py
@@ -10,7 +10,7 @@
import numpy as np
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
__all__ = [
"mobilenet_v3_small_crop_orientation",
@@ -43,6 +43,7 @@ class MobileNetV3(Engine):
Args:
----
model_path: path or url to onnx model file
+ engine_cfg: configuration for the inference engine
cfg: configuration dictionary
**kwargs: additional arguments to be passed to `Engine`
"""
@@ -50,10 +51,11 @@ class MobileNetV3(Engine):
def __init__(
self,
model_path: str,
+ engine_cfg: EngineConfig = EngineConfig(),
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
- super().__init__(url=model_path, **kwargs)
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.cfg = cfg
def __call__(
@@ -67,17 +69,19 @@ def _mobilenet_v3(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> MobileNetV3:
# Patch the url
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
_cfg = deepcopy(default_cfgs[arch])
- return MobileNetV3(model_path, cfg=_cfg, **kwargs)
+ return MobileNetV3(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
def mobilenet_v3_small_crop_orientation(
model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
@@ -94,18 +98,20 @@ def mobilenet_v3_small_crop_orientation(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the MobileNetV3 architecture
Returns:
-------
MobileNetV3
"""
- return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, **kwargs)
+ return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
def mobilenet_v3_small_page_orientation(
model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
@@ -122,10 +128,11 @@ def mobilenet_v3_small_page_orientation(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the MobileNetV3 architecture
Returns:
-------
MobileNetV3
"""
- return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, **kwargs)
+ return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/classification/zoo.py b/onnxtr/models/classification/zoo.py
index ef91e9e..ad0b50b 100644
--- a/onnxtr/models/classification/zoo.py
+++ b/onnxtr/models/classification/zoo.py
@@ -5,6 +5,8 @@
from typing import Any, List
+from onnxtr.models.engine import EngineConfig
+
from .. import classification
from ..preprocessor import PreProcessor
from .predictor import OrientationPredictor
@@ -14,12 +16,14 @@
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
-def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any) -> OrientationPredictor:
+def _orientation_predictor(
+ arch: str, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any
+) -> OrientationPredictor:
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")
# Load directly classifier from backbone
- _model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit)
+ _model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg)
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
@@ -32,7 +36,10 @@ def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any
def crop_orientation_predictor(
- arch: Any = "mobilenet_v3_small_crop_orientation", load_in_8_bit: bool = False, **kwargs: Any
+ arch: Any = "mobilenet_v3_small_crop_orientation",
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> OrientationPredictor:
"""Crop orientation classification architecture.
@@ -46,17 +53,21 @@ def crop_orientation_predictor(
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
load_in_8_bit: load the 8-bit quantized version of the model
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments to be passed to the OrientationPredictor
Returns:
-------
OrientationPredictor
"""
- return _orientation_predictor(arch, load_in_8_bit, **kwargs)
+ return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
def page_orientation_predictor(
- arch: Any = "mobilenet_v3_small_page_orientation", load_in_8_bit: bool = False, **kwargs: Any
+ arch: Any = "mobilenet_v3_small_page_orientation",
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> OrientationPredictor:
"""Page orientation classification architecture.
@@ -70,10 +81,11 @@ def page_orientation_predictor(
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments to be passed to the OrientationPredictor
Returns:
-------
OrientationPredictor
"""
- return _orientation_predictor(arch, load_in_8_bit, **kwargs)
+ return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/detection/models/differentiable_binarization.py b/onnxtr/models/detection/models/differentiable_binarization.py
index 53be535..1747bcb 100644
--- a/onnxtr/models/detection/models/differentiable_binarization.py
+++ b/onnxtr/models/detection/models/differentiable_binarization.py
@@ -8,7 +8,7 @@
import numpy as np
from scipy.special import expit
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
from ..postprocessor.base import GeneralDetectionPostProcessor
__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
@@ -33,8 +33,8 @@
"input_shape": (3, 1024, 1024),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
- "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_mobilenet_v3_large-1866973f.onnx",
- "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_mobilenet_v3_large_static_8_bit-51659bb9.onnx",
+ "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large-4987e7bd.onnx",
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large_static_8_bit-535a6f25.onnx",
},
}
@@ -45,6 +45,7 @@ class DBNet(Engine):
Args:
----
model_path: path or url to onnx model file
+ engine_cfg: configuration for the inference engine
bin_thresh: threshold for binarization of the output feature map
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
@@ -54,14 +55,15 @@ class DBNet(Engine):
def __init__(
self,
- model_path,
+ model_path: str,
+ engine_cfg: EngineConfig = EngineConfig(),
bin_thresh: float = 0.3,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
- super().__init__(url=model_path, **kwargs)
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.cfg = cfg
self.assume_straight_pages = assume_straight_pages
self.postprocessor = GeneralDetectionPostProcessor(
@@ -91,16 +93,20 @@ def _dbnet(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> DBNet:
# Patch the url
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
# Build the model
- return DBNet(model_path, cfg=default_cfgs[arch], **kwargs)
+ return DBNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
def db_resnet34(
- model_path: str = default_cfgs["db_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["db_resnet34"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
`_, using a ResNet-34 backbone.
@@ -115,17 +121,21 @@ def db_resnet34(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
- return _dbnet("db_resnet34", model_path, load_in_8_bit, **kwargs)
+ return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
def db_resnet50(
- model_path: str = default_cfgs["db_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["db_resnet50"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
`_, using a ResNet-50 backbone.
@@ -140,17 +150,21 @@ def db_resnet50(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
- return _dbnet("db_resnet50", model_path, load_in_8_bit, **kwargs)
+ return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
def db_mobilenet_v3_large(
- model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
`_, using a MobileNet V3 Large backbone.
@@ -165,10 +179,11 @@ def db_mobilenet_v3_large(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
- return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs)
+ return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/detection/models/fast.py b/onnxtr/models/detection/models/fast.py
index 30d96c3..d8ee844 100644
--- a/onnxtr/models/detection/models/fast.py
+++ b/onnxtr/models/detection/models/fast.py
@@ -9,7 +9,7 @@
import numpy as np
from scipy.special import expit
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
from ..postprocessor.base import GeneralDetectionPostProcessor
__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"]
@@ -43,6 +43,7 @@ class FAST(Engine):
Args:
----
model_path: path or url to onnx model file
+ engine_cfg: configuration for the inference engine
bin_thresh: threshold for binarization of the output feature map
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
@@ -53,13 +54,14 @@ class FAST(Engine):
def __init__(
self,
model_path: str,
+ engine_cfg: EngineConfig = EngineConfig(),
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
- super().__init__(url=model_path, **kwargs)
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.cfg = cfg
self.assume_straight_pages = assume_straight_pages
@@ -90,15 +92,21 @@ def _fast(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> FAST:
if load_in_8_bit:
logging.warning("FAST models do not support 8-bit quantization yet. Loading full precision model...")
# Build the model
- return FAST(model_path, cfg=default_cfgs[arch], **kwargs)
+ return FAST(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
-def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
+def fast_tiny(
+ model_path: str = default_cfgs["fast_tiny"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
+) -> FAST:
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
`_, using a tiny TextNet backbone.
@@ -112,16 +120,22 @@ def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], load_in_8_bit:
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
- return _fast("fast_tiny", model_path, load_in_8_bit, **kwargs)
+ return _fast("fast_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs)
-def fast_small(model_path: str = default_cfgs["fast_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
+def fast_small(
+ model_path: str = default_cfgs["fast_small"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
+) -> FAST:
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
`_, using a small TextNet backbone.
@@ -135,16 +149,22 @@ def fast_small(model_path: str = default_cfgs["fast_small"]["url"], load_in_8_bi
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
- return _fast("fast_small", model_path, load_in_8_bit, **kwargs)
+ return _fast("fast_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
-def fast_base(model_path: str = default_cfgs["fast_base"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
+def fast_base(
+ model_path: str = default_cfgs["fast_base"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
+) -> FAST:
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
`_, using a base TextNet backbone.
@@ -158,10 +178,11 @@ def fast_base(model_path: str = default_cfgs["fast_base"]["url"], load_in_8_bit:
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
- return _fast("fast_base", model_path, load_in_8_bit, **kwargs)
+ return _fast("fast_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/detection/models/linknet.py b/onnxtr/models/detection/models/linknet.py
index 142be6b..852d1be 100644
--- a/onnxtr/models/detection/models/linknet.py
+++ b/onnxtr/models/detection/models/linknet.py
@@ -8,7 +8,7 @@
import numpy as np
from scipy.special import expit
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
from ..postprocessor.base import GeneralDetectionPostProcessor
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
@@ -45,6 +45,7 @@ class LinkNet(Engine):
Args:
----
model_path: path or url to onnx model file
+ engine_cfg: configuration for the inference engine
bin_thresh: threshold for binarization of the output feature map
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
@@ -55,13 +56,14 @@ class LinkNet(Engine):
def __init__(
self,
model_path: str,
+ engine_cfg: EngineConfig = EngineConfig(),
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
- super().__init__(url=model_path, **kwargs)
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.cfg = cfg
self.assume_straight_pages = assume_straight_pages
@@ -92,16 +94,20 @@ def _linknet(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> LinkNet:
# Patch the url
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
# Build the model
- return LinkNet(model_path, cfg=default_cfgs[arch], **kwargs)
+ return LinkNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs)
def linknet_resnet18(
- model_path: str = default_cfgs["linknet_resnet18"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["linknet_resnet18"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
`_.
@@ -116,17 +122,21 @@ def linknet_resnet18(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the LinkNet architecture
Returns:
-------
text detection architecture
"""
- return _linknet("linknet_resnet18", model_path, load_in_8_bit, **kwargs)
+ return _linknet("linknet_resnet18", model_path, load_in_8_bit, engine_cfg, **kwargs)
def linknet_resnet34(
- model_path: str = default_cfgs["linknet_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["linknet_resnet34"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
`_.
@@ -141,17 +151,21 @@ def linknet_resnet34(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the LinkNet architecture
Returns:
-------
text detection architecture
"""
- return _linknet("linknet_resnet34", model_path, load_in_8_bit, **kwargs)
+ return _linknet("linknet_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
def linknet_resnet50(
- model_path: str = default_cfgs["linknet_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["linknet_resnet50"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
`_.
@@ -166,10 +180,11 @@ def linknet_resnet50(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the LinkNet architecture
Returns:
-------
text detection architecture
"""
- return _linknet("linknet_resnet50", model_path, load_in_8_bit, **kwargs)
+ return _linknet("linknet_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/detection/zoo.py b/onnxtr/models/detection/zoo.py
index e241a9d..cda0eed 100644
--- a/onnxtr/models/detection/zoo.py
+++ b/onnxtr/models/detection/zoo.py
@@ -6,6 +6,7 @@
from typing import Any
from .. import detection
+from ..engine import EngineConfig
from ..preprocessor import PreProcessor
from .predictor import DetectionPredictor
@@ -25,13 +26,19 @@
def _predictor(
- arch: Any, assume_straight_pages: bool = True, load_in_8_bit: bool = False, **kwargs: Any
+ arch: Any,
+ assume_straight_pages: bool = True,
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> DetectionPredictor:
if isinstance(arch, str):
if arch not in ARCHS:
raise ValueError(f"unknown architecture '{arch}'")
- _model = detection.__dict__[arch](assume_straight_pages=assume_straight_pages, load_in_8_bit=load_in_8_bit)
+ _model = detection.__dict__[arch](
+ assume_straight_pages=assume_straight_pages, load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg
+ )
else:
if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
raise ValueError(f"unknown architecture: {type(arch)}")
@@ -53,6 +60,7 @@ def detection_predictor(
arch: Any = "fast_base",
assume_straight_pages: bool = True,
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> DetectionPredictor:
"""Text detection architecture.
@@ -68,10 +76,11 @@ def detection_predictor(
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
assume_straight_pages: If True, fit straight boxes to the page
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: optional keyword arguments passed to the architecture
Returns:
-------
Detection predictor
"""
- return _predictor(arch, assume_straight_pages, load_in_8_bit, **kwargs)
+ return _predictor(arch, assume_straight_pages, load_in_8_bit, engine_cfg=engine_cfg, **kwargs)
diff --git a/onnxtr/models/engine.py b/onnxtr/models/engine.py
index 7ca8183..b8035aa 100644
--- a/onnxtr/models/engine.py
+++ b/onnxtr/models/engine.py
@@ -3,14 +3,79 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to for full license details.
-from typing import Any, List, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
-from onnxruntime import ExecutionMode, GraphOptimizationLevel, InferenceSession, SessionOptions
+from onnxruntime import (
+ ExecutionMode,
+ GraphOptimizationLevel,
+ InferenceSession,
+ SessionOptions,
+ get_available_providers,
+ get_device,
+)
from onnxtr.utils.data import download_from_url
from onnxtr.utils.geometry import shape_translate
+__all__ = ["EngineConfig"]
+
+
+class EngineConfig:
+ """Implements a configuration class for the engine of a model
+
+ Args:
+ ----
+ providers: list of providers to use for inference ref.: https://onnxruntime.ai/docs/execution-providers/
+ session_options: configuration for the inference session ref.: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
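+
+ A minimal usage sketch (CPU-only inference with custom session options):
+
+ >>> from onnxruntime import SessionOptions
+ >>> from onnxtr.models import EngineConfig
+ >>> opts = SessionOptions()
+ >>> opts.intra_op_num_threads = 2
+ >>> cfg = EngineConfig(providers=["CPUExecutionProvider"], session_options=opts)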
+ """
+
+ def __init__(
+ self,
+ providers: Optional[Union[List[Tuple[str, Dict[str, Any]]], List[str]]] = None,
+ session_options: Optional[SessionOptions] = None,
+ ):
+ self._providers = providers or self._init_providers()
+ self._session_options = session_options or self._init_sess_opts()
+
+ def _init_providers(self) -> List[Tuple[str, Dict[str, Any]]]:
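+ # Default to CPU execution; prepend CUDA with sensible options when both the provider and a GPU device are available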
+ providers: Any = [("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
+ available_providers = get_available_providers()
+ if "CUDAExecutionProvider" in available_providers and get_device() == "GPU": # pragma: no cover
+ providers.insert(
+ 0,
+ (
+ "CUDAExecutionProvider",
+ {
+ "device_id": 0,
+ "arena_extend_strategy": "kNextPowerOfTwo",
+ "cudnn_conv_algo_search": "EXHAUSTIVE",
+ "do_copy_in_default_stream": True,
+ },
+ ),
+ )
+ return providers
+
+ def _init_sess_opts(self) -> SessionOptions:
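+ # Default session options: sequential execution mode with all graph optimizations enabled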
+ session_options = SessionOptions()
+ session_options.enable_cpu_mem_arena = True
+ session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
+ session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+ session_options.intra_op_num_threads = -1
+ session_options.inter_op_num_threads = -1
+ return session_options
+
+ @property
+ def providers(self) -> Union[List[Tuple[str, Dict[str, Any]]], List[str]]:
+ return self._providers
+
+ @property
+ def session_options(self) -> SessionOptions:
+ return self._session_options
+
+ def __repr__(self) -> str:
+ return f"EngineConfig(providers={self.providers}"
+
class Engine:
"""Implements an abstract class for the engine of a model
@@ -18,16 +83,16 @@ class Engine:
Args:
----
url: the url to use to download a model if needed
- providers: list of providers to use for inference
+ engine_cfg: the configuration of the engine
**kwargs: additional arguments to be passed to `download_from_url`
"""
- def __init__(
- self, url: str, providers: List[str] = ["CPUExecutionProvider", "CUDAExecutionProvider"], **kwargs: Any
- ) -> None:
+ def __init__(self, url: str, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any) -> None:
+ engine_cfg = engine_cfg or EngineConfig()
archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
- session_options = self._init_sess_opts()
- self.runtime = InferenceSession(archive_path, providers=providers, sess_options=session_options)
+ self.session_options = engine_cfg.session_options
+ self.providers = engine_cfg.providers
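+ # Build the onnxruntime inference session from the providers and session options of the given engine configuration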
+ self.runtime = InferenceSession(archive_path, providers=self.providers, sess_options=self.session_options)
self.runtime_inputs = self.runtime.get_inputs()[0]
self.tf_exported = int(self.runtime_inputs.shape[-1]) == 3
self.fixed_batch_size: Union[int, str] = self.runtime_inputs.shape[
@@ -35,15 +100,6 @@ def __init__(
] # mostly possible with tensorflow exported models
self.output_name = [output.name for output in self.runtime.get_outputs()]
- def _init_sess_opts(self) -> SessionOptions:
- session_options = SessionOptions()
- session_options.enable_cpu_mem_arena = True
- session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
- session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
- session_options.intra_op_num_threads = -1
- session_options.inter_op_num_threads = -1
- return session_options
-
def run(self, inputs: np.ndarray) -> np.ndarray:
if self.tf_exported:
inputs = shape_translate(inputs, format="BHWC") # sanity check
diff --git a/onnxtr/models/predictor/base.py b/onnxtr/models/predictor/base.py
index aa0a690..d31deb7 100644
--- a/onnxtr/models/predictor/base.py
+++ b/onnxtr/models/predictor/base.py
@@ -8,6 +8,7 @@
import numpy as np
from onnxtr.models.builder import DocumentBuilder
+from onnxtr.models.engine import EngineConfig
from onnxtr.utils.geometry import extract_crops, extract_rcrops, rotate_image
from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
@@ -34,6 +35,7 @@ class _OCRPredictor:
detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
page. Doing so will slightly deteriorate the overall latency.
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ clf_engine_cfg: configuration of the orientation classification engine
**kwargs: keyword args of `DocumentBuilder`
"""
@@ -48,15 +50,18 @@ def __init__(
symmetric_pad: bool = True,
detect_orientation: bool = False,
load_in_8_bit: bool = False,
+ clf_engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> None:
self.assume_straight_pages = assume_straight_pages
self.straighten_pages = straighten_pages
self.crop_orientation_predictor = (
- None if assume_straight_pages else crop_orientation_predictor(load_in_8_bit=load_in_8_bit)
+ None
+ if assume_straight_pages
+ else crop_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
)
self.page_orientation_predictor = (
- page_orientation_predictor(load_in_8_bit=load_in_8_bit)
+ page_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
if detect_orientation or straighten_pages or not assume_straight_pages
else None
)
diff --git a/onnxtr/models/predictor/predictor.py b/onnxtr/models/predictor/predictor.py
index ab54311..20c8b79 100644
--- a/onnxtr/models/predictor/predictor.py
+++ b/onnxtr/models/predictor/predictor.py
@@ -10,6 +10,7 @@
from onnxtr.io.elements import Document
from onnxtr.models._utils import get_language
from onnxtr.models.detection.predictor import DetectionPredictor
+from onnxtr.models.engine import EngineConfig
from onnxtr.models.recognition.predictor import RecognitionPredictor
from onnxtr.utils.geometry import detach_scores
from onnxtr.utils.repr import NestedObject
@@ -35,6 +36,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
page. Doing so will slightly deteriorate the overall latency.
detect_language: if True, the language prediction will be added to the predictions for each
page. Doing so will slightly deteriorate the overall latency.
+ clf_engine_cfg: configuration of the orientation classification engine
**kwargs: keyword args of `DocumentBuilder`
"""
@@ -50,6 +52,7 @@ def __init__(
symmetric_pad: bool = True,
detect_orientation: bool = False,
detect_language: bool = False,
+ clf_engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> None:
self.det_predictor = det_predictor
@@ -61,6 +64,7 @@ def __init__(
preserve_aspect_ratio,
symmetric_pad,
detect_orientation,
+ clf_engine_cfg=clf_engine_cfg,
**kwargs,
)
self.detect_orientation = detect_orientation
diff --git a/onnxtr/models/recognition/models/crnn.py b/onnxtr/models/recognition/models/crnn.py
index a969de5..3ce0181 100644
--- a/onnxtr/models/recognition/models/crnn.py
+++ b/onnxtr/models/recognition/models/crnn.py
@@ -12,7 +12,7 @@
from onnxtr.utils import VOCABS
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
from ..core import RecognitionPostProcessor
__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -113,6 +113,7 @@ class CRNN(Engine):
----
model_path: path or url to onnx model file
vocab: vocabulary used for encoding
+ engine_cfg: configuration for the inference engine
cfg: configuration dictionary
**kwargs: additional arguments to be passed to `Engine`
"""
@@ -123,10 +124,11 @@ def __init__(
self,
model_path: str,
vocab: str,
+ engine_cfg: EngineConfig = EngineConfig(),
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
- super().__init__(url=model_path, **kwargs)
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.vocab = vocab
self.cfg = cfg
self.postprocessor = CRNNPostProcessor(self.vocab)
@@ -152,6 +154,7 @@ def _crnn(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> CRNN:
kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
@@ -163,11 +166,14 @@ def _crnn(
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
# Build the model
- return CRNN(model_path, cfg=_cfg, **kwargs)
+ return CRNN(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
def crnn_vgg16_bn(
- model_path: str = default_cfgs["crnn_vgg16_bn"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["crnn_vgg16_bn"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> CRNN:
"""CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
Sequence Recognition and Its Application to Scene Text Recognition" `_.
@@ -182,17 +188,21 @@ def crnn_vgg16_bn(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the CRNN architecture
Returns:
-------
text recognition architecture
"""
- return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, **kwargs)
+ return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, engine_cfg, **kwargs)
def crnn_mobilenet_v3_small(
- model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> CRNN:
"""CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
Sequence Recognition and Its Application to Scene Text Recognition" `_.
@@ -207,17 +217,21 @@ def crnn_mobilenet_v3_small(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the CRNN architecture
Returns:
-------
text recognition architecture
"""
- return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, **kwargs)
+ return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
def crnn_mobilenet_v3_large(
- model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> CRNN:
"""CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
Sequence Recognition and Its Application to Scene Text Recognition" `_.
@@ -232,10 +246,11 @@ def crnn_mobilenet_v3_large(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the CRNN architecture
Returns:
-------
text recognition architecture
"""
- return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs)
+ return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/recognition/models/master.py b/onnxtr/models/recognition/models/master.py
index 164f42b..10cadbc 100644
--- a/onnxtr/models/recognition/models/master.py
+++ b/onnxtr/models/recognition/models/master.py
@@ -11,7 +11,7 @@
from onnxtr.utils import VOCABS
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
from ..core import RecognitionPostProcessor
__all__ = ["MASTER", "master"]
@@ -36,6 +36,7 @@ class MASTER(Engine):
----
model_path: path or url to onnx model file
vocab: vocabulary, (without EOS, SOS, PAD)
+ engine_cfg: configuration for the inference engine
cfg: dictionary containing information about the model
**kwargs: additional arguments to be passed to `Engine`
"""
@@ -44,10 +45,11 @@ def __init__(
self,
model_path: str,
vocab: str,
+ engine_cfg: EngineConfig = EngineConfig(),
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
- super().__init__(url=model_path, **kwargs)
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.vocab = vocab
self.cfg = cfg
@@ -114,6 +116,7 @@ def _master(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> MASTER:
# Patch the config
@@ -125,10 +128,15 @@ def _master(
# Patch the url
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
- return MASTER(model_path, cfg=_cfg, **kwargs)
+ return MASTER(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
-def master(model_path: str = default_cfgs["master"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> MASTER:
+def master(
+ model_path: str = default_cfgs["master"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
+) -> MASTER:
"""MASTER as described in paper: `_.
>>> import numpy as np
@@ -141,10 +149,11 @@ def master(model_path: str = default_cfgs["master"]["url"], load_in_8_bit: bool
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keywoard arguments passed to the MASTER architecture
Returns:
-------
text recognition architecture
"""
- return _master("master", model_path, load_in_8_bit, **kwargs)
+ return _master("master", model_path, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/recognition/models/parseq.py b/onnxtr/models/recognition/models/parseq.py
index f332f45..8ccdec6 100644
--- a/onnxtr/models/recognition/models/parseq.py
+++ b/onnxtr/models/recognition/models/parseq.py
@@ -11,7 +11,7 @@
from onnxtr.utils import VOCABS
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
from ..core import RecognitionPostProcessor
__all__ = ["PARSeq", "parseq"]
@@ -35,6 +35,7 @@ class PARSeq(Engine):
----
model_path: path to onnx model file
vocab: vocabulary used for encoding
+ engine_cfg: configuration for the inference engine
cfg: dictionary containing information about the model
**kwargs: additional arguments to be passed to `Engine`
"""
@@ -43,10 +44,11 @@ def __init__(
self,
model_path: str,
vocab: str,
+ engine_cfg: EngineConfig = EngineConfig(),
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
- super().__init__(url=model_path, **kwargs)
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.vocab = vocab
self.cfg = cfg
self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
@@ -102,6 +104,7 @@ def _parseq(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> PARSeq:
# Patch the config
@@ -114,10 +117,15 @@ def _parseq(
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
# Build the model
- return PARSeq(model_path, cfg=_cfg, **kwargs)
+ return PARSeq(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
-def parseq(model_path: str = default_cfgs["parseq"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> PARSeq:
+def parseq(
+ model_path: str = default_cfgs["parseq"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
+) -> PARSeq:
"""PARSeq architecture from
`"Scene Text Recognition with Permuted Autoregressive Sequence Models" `_.
@@ -131,10 +139,11 @@ def parseq(model_path: str = default_cfgs["parseq"]["url"], load_in_8_bit: bool
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the PARSeq architecture
Returns:
-------
text recognition architecture
"""
- return _parseq("parseq", model_path, load_in_8_bit, **kwargs)
+ return _parseq("parseq", model_path, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/recognition/models/sar.py b/onnxtr/models/recognition/models/sar.py
index 975d61d..7ea2415 100644
--- a/onnxtr/models/recognition/models/sar.py
+++ b/onnxtr/models/recognition/models/sar.py
@@ -11,7 +11,7 @@
from onnxtr.utils import VOCABS
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
from ..core import RecognitionPostProcessor
__all__ = ["SAR", "sar_resnet31"]
@@ -35,6 +35,7 @@ class SAR(Engine):
----
model_path: path to onnx model file
vocab: vocabulary used for encoding
+ engine_cfg: configuration for the inference engine
cfg: dictionary containing information about the model
**kwargs: additional arguments to be passed to `Engine`
"""
@@ -43,10 +44,11 @@ def __init__(
self,
model_path: str,
vocab: str,
+ engine_cfg: EngineConfig = EngineConfig(),
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
- super().__init__(url=model_path, **kwargs)
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.vocab = vocab
self.cfg = cfg
self.postprocessor = SARPostProcessor(self.vocab)
@@ -101,6 +103,7 @@ def _sar(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> SAR:
# Patch the config
@@ -113,11 +116,14 @@ def _sar(
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
# Build the model
- return SAR(model_path, cfg=_cfg, **kwargs)
+ return SAR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
def sar_resnet31(
- model_path: str = default_cfgs["sar_resnet31"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["sar_resnet31"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> SAR:
"""SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong
Baseline for Irregular Text Recognition" `_.
@@ -132,10 +138,11 @@ def sar_resnet31(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the SAR architecture
Returns:
-------
text recognition architecture
"""
- return _sar("sar_resnet31", model_path, load_in_8_bit, **kwargs)
+ return _sar("sar_resnet31", model_path, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/recognition/models/vitstr.py b/onnxtr/models/recognition/models/vitstr.py
index 508cbe5..ab37d3e 100644
--- a/onnxtr/models/recognition/models/vitstr.py
+++ b/onnxtr/models/recognition/models/vitstr.py
@@ -11,7 +11,7 @@
from onnxtr.utils import VOCABS
-from ...engine import Engine
+from ...engine import Engine, EngineConfig
from ..core import RecognitionPostProcessor
__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -43,6 +43,7 @@ class ViTSTR(Engine):
----
model_path: path to onnx model file
vocab: vocabulary used for encoding
+ engine_cfg: configuration for the inference engine
cfg: dictionary containing information about the model
**kwargs: additional arguments to be passed to `Engine`
"""
@@ -51,10 +52,11 @@ def __init__(
self,
model_path: str,
vocab: str,
+ engine_cfg: EngineConfig = EngineConfig(),
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
- super().__init__(url=model_path, **kwargs)
+ super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
self.vocab = vocab
self.cfg = cfg
@@ -112,6 +114,7 @@ def _vitstr(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> ViTSTR:
# Patch the config
@@ -124,11 +127,14 @@ def _vitstr(
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
# Build the model
- return ViTSTR(model_path, cfg=_cfg, **kwargs)
+ return ViTSTR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs)
def vitstr_small(
- model_path: str = default_cfgs["vitstr_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["vitstr_small"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> ViTSTR:
"""ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
`_.
@@ -143,17 +149,21 @@ def vitstr_small(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the ViTSTR architecture
Returns:
-------
text recognition architecture
"""
- return _vitstr("vitstr_small", model_path, load_in_8_bit, **kwargs)
+ return _vitstr("vitstr_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
def vitstr_base(
- model_path: str = default_cfgs["vitstr_base"]["url"], load_in_8_bit: bool = False, **kwargs: Any
+ model_path: str = default_cfgs["vitstr_base"]["url"],
+ load_in_8_bit: bool = False,
+ engine_cfg: EngineConfig = EngineConfig(),
+ **kwargs: Any,
) -> ViTSTR:
"""ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
`_.
@@ -168,10 +178,11 @@ def vitstr_base(
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: keyword arguments of the ViTSTR architecture
Returns:
-------
text recognition architecture
"""
- return _vitstr("vitstr_base", model_path, load_in_8_bit, **kwargs)
+ return _vitstr("vitstr_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/recognition/zoo.py b/onnxtr/models/recognition/zoo.py
index 8f3fedf..d237290 100644
--- a/onnxtr/models/recognition/zoo.py
+++ b/onnxtr/models/recognition/zoo.py
@@ -5,9 +5,9 @@
from typing import Any, List
-from onnxtr.models.preprocessor import PreProcessor
-
from .. import recognition
+from ..engine import EngineConfig
+from ..preprocessor import PreProcessor
from .predictor import RecognitionPredictor
__all__ = ["recognition_predictor"]
@@ -25,12 +25,14 @@
]
-def _predictor(arch: Any, load_in_8_bit: bool = False, **kwargs: Any) -> RecognitionPredictor:
+def _predictor(
+ arch: Any, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any
+) -> RecognitionPredictor:
if isinstance(arch, str):
if arch not in ARCHS:
raise ValueError(f"unknown architecture '{arch}'")
- _model = recognition.__dict__[arch](load_in_8_bit=load_in_8_bit)
+ _model = recognition.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg)
else:
if not isinstance(
arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
@@ -48,7 +50,7 @@ def _predictor(arch: Any, load_in_8_bit: bool = False, **kwargs: Any) -> Recogni
def recognition_predictor(
- arch: Any = "crnn_vgg16_bn", load_in_8_bit: bool = False, **kwargs: Any
+ arch: Any = "crnn_vgg16_bn", load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any
) -> RecognitionPredictor:
"""Text recognition architecture.
@@ -63,10 +65,11 @@ def recognition_predictor(
----
arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ engine_cfg: configuration for the inference engine
**kwargs: optional parameters to be passed to the architecture
Returns:
-------
Recognition predictor
"""
- return _predictor(arch, load_in_8_bit, **kwargs)
+ return _predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
diff --git a/onnxtr/models/zoo.py b/onnxtr/models/zoo.py
index ba650c9..f681306 100644
--- a/onnxtr/models/zoo.py
+++ b/onnxtr/models/zoo.py
@@ -6,6 +6,7 @@
from typing import Any
from .detection.zoo import detection_predictor
+from .engine import EngineConfig
from .predictor import OCRPredictor
from .recognition.zoo import recognition_predictor
@@ -24,6 +25,9 @@ def _predictor(
straighten_pages: bool = False,
detect_language: bool = False,
load_in_8_bit: bool = False,
+ det_engine_cfg: EngineConfig = EngineConfig(),
+ reco_engine_cfg: EngineConfig = EngineConfig(),
+ clf_engine_cfg: EngineConfig = EngineConfig(),
**kwargs,
) -> OCRPredictor:
# Detection
@@ -34,6 +38,7 @@ def _predictor(
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
load_in_8_bit=load_in_8_bit,
+ engine_cfg=det_engine_cfg,
)
# Recognition
@@ -41,6 +46,7 @@ def _predictor(
reco_arch,
batch_size=reco_bs,
load_in_8_bit=load_in_8_bit,
+ engine_cfg=reco_engine_cfg,
)
return OCRPredictor(
@@ -52,6 +58,7 @@ def _predictor(
detect_orientation=detect_orientation,
straighten_pages=straighten_pages,
detect_language=detect_language,
+ clf_engine_cfg=clf_engine_cfg,
**kwargs,
)
@@ -67,6 +74,9 @@ def ocr_predictor(
straighten_pages: bool = False,
detect_language: bool = False,
load_in_8_bit: bool = False,
+ det_engine_cfg: EngineConfig = EngineConfig(),
+ reco_engine_cfg: EngineConfig = EngineConfig(),
+ clf_engine_cfg: EngineConfig = EngineConfig(),
**kwargs: Any,
) -> OCRPredictor:
"""End-to-end OCR architecture using one model for localization, and another for text recognition.
@@ -99,6 +109,9 @@ def ocr_predictor(
detect_language: if True, the language prediction will be added to the predictions for each
page. Doing so will slightly deteriorate the overall latency.
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
+ det_engine_cfg: configuration of the detection engine
+ reco_engine_cfg: configuration of the recognition engine
+ clf_engine_cfg: configuration of the orientation classification engine
kwargs: keyword args of `OCRPredictor`
Returns:
@@ -116,5 +129,8 @@ def ocr_predictor(
straighten_pages=straighten_pages,
detect_language=detect_language,
load_in_8_bit=load_in_8_bit,
+ det_engine_cfg=det_engine_cfg,
+ reco_engine_cfg=reco_engine_cfg,
+ clf_engine_cfg=clf_engine_cfg,
**kwargs,
)
diff --git a/scripts/quantize.py b/scripts/quantize.py
index e40596e..93b6a8b 100644
--- a/scripts/quantize.py
+++ b/scripts/quantize.py
@@ -140,15 +140,27 @@ def main():
# Turn off model optimization during quantization
if "parseq" not in input_model_path: # Skip static quantization for Parseq
print("Calibrating and quantizing model static...")
- quantize_static(
- input_model_path,
- static_out_name,
- dr,
- quant_format=args.quant_format,
- weight_type=QuantType.QUInt8,
- activation_type=QuantType.QUInt8,
- reduce_range=True,
- )
+ try:
+ quantize_static(
+ input_model_path,
+ static_out_name,
+ dr,
+ quant_format=args.quant_format,
+ weight_type=QuantType.QInt8,
+ activation_type=QuantType.QUInt8,
+ reduce_range=True,
+ )
+ except Exception:
+ print("Error during static quantization --> Change weight_type also to QUInt8")
+ quantize_static(
+ input_model_path,
+ static_out_name,
+ dr,
+ quant_format=args.quant_format,
+ weight_type=QuantType.QUInt8,
+ activation_type=QuantType.QUInt8,
+ reduce_range=True,
+ )
print("benchmarking static int8 model...")
benchmark(calibration_dataset_path, static_out_name, task_shape)
diff --git a/setup.py b/setup.py
index 9ce9d41..e2a55d2 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@
from setuptools import setup
PKG_NAME = "onnxtr"
-VERSION = os.getenv("BUILD_VERSION", "0.2.1a0")
+VERSION = os.getenv("BUILD_VERSION", "0.3.0a0")
if __name__ == "__main__":
diff --git a/tests/common/test_engine_cfg.py b/tests/common/test_engine_cfg.py
new file mode 100644
index 0000000..f229444
--- /dev/null
+++ b/tests/common/test_engine_cfg.py
@@ -0,0 +1,74 @@
+import numpy as np
+import pytest
+from onnxruntime import SessionOptions
+
+from onnxtr import models
+from onnxtr.io import Document
+from onnxtr.models import EngineConfig, detection, recognition
+from onnxtr.models.predictor import OCRPredictor
+
+
+def _test_predictor(predictor):
+ # Output checks
+ assert isinstance(predictor, OCRPredictor)
+
+ doc = [np.zeros((1024, 1024, 3), dtype=np.uint8)]
+ out = predictor(doc)
+ # Document
+ assert isinstance(out, Document)
+
+ # The input doc has 1 page
+ assert len(out.pages) == 1
+ # Dimension check
+ with pytest.raises(ValueError):
+ input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8)
+ _ = predictor([input_page])
+
+
+@pytest.mark.parametrize(
+ "det_arch, reco_arch",
+ [[det_arch, reco_arch] for det_arch, reco_arch in zip(detection.zoo.ARCHS, recognition.zoo.ARCHS)],
+)
+def test_engine_cfg(det_arch, reco_arch):
+ session_options = SessionOptions()
+ session_options.enable_cpu_mem_arena = False
+ engine_cfg = EngineConfig(
+ providers=["CPUExecutionProvider"],
+ session_options=session_options,
+ )
+
+ assert repr(engine_cfg) == "EngineConfig(providers=['CPUExecutionProvider'])"
+
+ # Model
+ predictor = models.ocr_predictor(
+ det_arch, reco_arch, det_engine_cfg=engine_cfg, reco_engine_cfg=engine_cfg, clf_engine_cfg=engine_cfg
+ )
+ assert predictor.det_predictor.model.providers == ["CPUExecutionProvider"]
+ assert not predictor.det_predictor.model.session_options.enable_cpu_mem_arena
+ assert predictor.reco_predictor.model.providers == ["CPUExecutionProvider"]
+ assert not predictor.reco_predictor.model.session_options.enable_cpu_mem_arena
+ _test_predictor(predictor)
+
+ # passing model instance directly
+ det_model = detection.__dict__[det_arch](engine_cfg=engine_cfg)
+ assert det_model.providers == ["CPUExecutionProvider"]
+ assert not det_model.session_options.enable_cpu_mem_arena
+
+ reco_model = recognition.__dict__[reco_arch](engine_cfg=engine_cfg)
+ assert reco_model.providers == ["CPUExecutionProvider"]
+ assert not reco_model.session_options.enable_cpu_mem_arena
+
+ predictor = models.ocr_predictor(det_model, reco_model)
+ assert predictor.det_predictor.model.providers == ["CPUExecutionProvider"]
+ assert not predictor.det_predictor.model.session_options.enable_cpu_mem_arena
+ assert predictor.reco_predictor.model.providers == ["CPUExecutionProvider"]
+ assert not predictor.reco_predictor.model.session_options.enable_cpu_mem_arena
+ _test_predictor(predictor)
+
+ det_predictor = models.detection_predictor(det_arch, engine_cfg=engine_cfg)
+ assert det_predictor.model.providers == ["CPUExecutionProvider"]
+ assert not det_predictor.model.session_options.enable_cpu_mem_arena
+
+ reco_predictor = models.recognition_predictor(reco_arch, engine_cfg=engine_cfg)
+ assert reco_predictor.model.providers == ["CPUExecutionProvider"]
+ assert not reco_predictor.model.session_options.enable_cpu_mem_arena
diff --git a/tests/common/test_models_zoo.py b/tests/common/test_models_zoo.py
index 8937242..3edc0c9 100644
--- a/tests/common/test_models_zoo.py
+++ b/tests/common/test_models_zoo.py
@@ -52,6 +52,8 @@ def test_ocrpredictor(mock_pdf, assume_straight_pages, straighten_pages):
straighten_pages=straighten_pages,
detect_orientation=True,
detect_language=True,
+ resolve_lines=True,
+ resolve_blocks=True,
)
if assume_straight_pages:
@@ -72,8 +74,7 @@ def test_ocrpredictor(mock_pdf, assume_straight_pages, straighten_pages):
input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8)
_ = predictor([input_page])
- orientation = 0
- assert out.pages[0].orientation["value"] == orientation
+ assert out.pages[0].orientation["value"] in range(-2, 3)
assert isinstance(out.pages[0].language["value"], str)
assert isinstance(out.render(), str)
assert isinstance(out.pages[0].render(), str)
@@ -102,6 +103,8 @@ def test_trained_ocr_predictor(mock_payslip):
assume_straight_pages=True,
straighten_pages=True,
preserve_aspect_ratio=False,
+ resolve_lines=True,
+ resolve_blocks=True,
)
# test hooks
predictor.add_hook(_DummyCallback())
@@ -131,6 +134,8 @@ def test_trained_ocr_predictor(mock_payslip):
straighten_pages=True,
preserve_aspect_ratio=True,
symmetric_pad=True,
+ resolve_lines=True,
+ resolve_blocks=True,
)
out = predictor(doc)