From 59179b4223d042cf7b752f9f76dff0edbdd835ce Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 27 Jun 2024 10:59:01 +0200 Subject: [PATCH 1/4] make engine configurable --- .pre-commit-config.yaml | 4 +- README.md | 20 ++++- onnxtr/models/__init__.py | 1 + onnxtr/models/builder.py | 2 +- .../models/classification/models/mobilenet.py | 17 ++-- onnxtr/models/classification/zoo.py | 24 +++-- .../models/differentiable_binarization.py | 39 +++++--- onnxtr/models/detection/models/fast.py | 39 ++++++-- onnxtr/models/detection/models/linknet.py | 33 +++++-- onnxtr/models/detection/zoo.py | 15 +++- onnxtr/models/engine.py | 89 +++++++++++++++---- onnxtr/models/predictor/base.py | 9 +- onnxtr/models/predictor/predictor.py | 4 + onnxtr/models/recognition/models/crnn.py | 33 +++++-- onnxtr/models/recognition/models/master.py | 19 ++-- onnxtr/models/recognition/models/parseq.py | 19 ++-- onnxtr/models/recognition/models/sar.py | 17 ++-- onnxtr/models/recognition/models/vitstr.py | 25 ++++-- onnxtr/models/recognition/zoo.py | 15 ++-- onnxtr/models/zoo.py | 16 ++++ scripts/quantize.py | 30 +++++-- setup.py | 2 +- tests/common/test_engine_cfg.py | 66 ++++++++++++++ tests/common/test_models_zoo.py | 9 +- 24 files changed, 430 insertions(+), 117 deletions(-) create mode 100644 tests/common/test_engine_cfg.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8e66bb4..a105406 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-ast - id: check-yaml @@ -16,7 +16,7 @@ repos: - id: no-commit-to-branch args: ['--branch', 'main'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.2 + rev: v0.4.10 hooks: - id: ruff args: [ --fix ] diff --git a/README.md b/README.md index 458a136..23f5fb0 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ Let's use the default `ocr_predictor` model for an example: ```python from onnxtr.io import DocumentFile -from onnxtr.models import ocr_predictor +from onnxtr.models import ocr_predictor, EngineConfig model = ocr_predictor( det_arch='fast_base', # detection architecture @@ -89,11 +89,15 @@ model = ocr_predictor( detect_language=False, # set to `True` if the language of the pages should be detected (default: False) # DocumentBuilder specific parameters resolve_lines=True, # whether words should be automatically grouped into lines (default: True) - resolve_blocks=True, # whether lines should be automatically grouped into blocks (default: True) + resolve_blocks=False, # whether lines should be automatically grouped into blocks (default: False) paragraph_break=0.035, # relative length of the minimum space separating paragraphs (default: 0.035) # OnnxTR specific parameters # NOTE: 8-Bit quantized models are not available for FAST detection models and can in general lead to poorer accuracy load_in_8_bit=False, # set to `True` to load 8-bit quantized models instead of the full precision onces (default: False) + # Advanced engine configuration options + det_engine_cfg=EngineConfig(), # detection model engine configuration (default: internal predefined configuration) + reco_engine_cfg=EngineConfig(), # recognition model engine configuration (default: internal predefined configuration) + clf_engine_cfg=EngineConfig(), # classification (orientation) model engine configuration (default: internal predefined configuration) ) # PDF doc = DocumentFile.from_pdf("path/to/your/doc.pdf") @@ -103,6 +107,18 @@ result = model(doc) 
 result.show()
 ```
+
+**Advanced engine configuration options**
+
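+A minimal sketch of a custom engine configuration, mirroring the usage exercised in `tests/common/test_engine_cfg.py`:
+
+```python
+from onnxruntime import SessionOptions
+
+from onnxtr.models import ocr_predictor, EngineConfig
+
+# Customize the session options, see:
+# https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
+session_options = SessionOptions()
+session_options.enable_cpu_mem_arena = False
+
+# Choose the execution providers, see:
+# https://onnxruntime.ai/docs/execution-providers/
+engine_config = EngineConfig(
+    providers=["CPUExecutionProvider"],
+    session_options=session_options,
+)
+
+# Each component takes its own configuration
+predictor = ocr_predictor(
+    det_engine_cfg=engine_config,
+    reco_engine_cfg=engine_config,
+    clf_engine_cfg=engine_config,
+)
+```
+
+If no configuration is passed, the internal predefined configuration is used: the CPU execution provider (with `arena_extend_strategy: kSameAsRequested`), plus the CUDA execution provider when it is available and a GPU device is detected.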
+ ![Visualization sample](https://github.com/felixdittrich92/OnnxTR/raw/main/docs/images/doctr_example_script.gif) Or even rebuild the original document from its predictions: diff --git a/onnxtr/models/__init__.py b/onnxtr/models/__init__.py index 4e4f327..e612989 100644 --- a/onnxtr/models/__init__.py +++ b/onnxtr/models/__init__.py @@ -1,3 +1,4 @@ +from .engine import EngineConfig from .classification import * from .detection import * from .recognition import * diff --git a/onnxtr/models/builder.py b/onnxtr/models/builder.py index 77eed6c..ee7f182 100644 --- a/onnxtr/models/builder.py +++ b/onnxtr/models/builder.py @@ -31,7 +31,7 @@ class DocumentBuilder(NestedObject): def __init__( self, resolve_lines: bool = True, - resolve_blocks: bool = True, + resolve_blocks: bool = False, paragraph_break: float = 0.035, export_as_straight_boxes: bool = False, ) -> None: diff --git a/onnxtr/models/classification/models/mobilenet.py b/onnxtr/models/classification/models/mobilenet.py index bdfbeda..583f146 100644 --- a/onnxtr/models/classification/models/mobilenet.py +++ b/onnxtr/models/classification/models/mobilenet.py @@ -10,7 +10,7 @@ import numpy as np -from ...engine import Engine +from ...engine import Engine, EngineConfig __all__ = [ "mobilenet_v3_small_crop_orientation", @@ -43,6 +43,7 @@ class MobileNetV3(Engine): Args: ---- model_path: path or url to onnx model file + engine_cfg: configuration for the inference engine cfg: configuration dictionary **kwargs: additional arguments to be passed to `Engine` """ @@ -50,10 +51,11 @@ class MobileNetV3(Engine): def __init__( self, model_path: str, + engine_cfg: EngineConfig = EngineConfig(), cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - super().__init__(url=model_path, **kwargs) + super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs) self.cfg = cfg def __call__( @@ -67,17 +69,19 @@ def _mobilenet_v3( arch: str, model_path: str, load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> MobileNetV3: # Patch the url model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path _cfg = deepcopy(default_cfgs[arch]) - return MobileNetV3(model_path, cfg=_cfg, **kwargs) + return MobileNetV3(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs) def mobilenet_v3_small_crop_orientation( model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"], load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> MobileNetV3: """MobileNetV3-Small architecture as described in @@ -94,18 +98,20 @@ def mobilenet_v3_small_crop_orientation( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the MobileNetV3 architecture Returns: ------- MobileNetV3 """ - return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, **kwargs) + return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs) def mobilenet_v3_small_page_orientation( model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"], load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> MobileNetV3: """MobileNetV3-Small architecture as described in @@ -122,10 +128,11 @@ def mobilenet_v3_small_page_orientation( ---- model_path: path to onnx model file, 
defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the MobileNetV3 architecture Returns: ------- MobileNetV3 """ - return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, **kwargs) + return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/classification/zoo.py b/onnxtr/models/classification/zoo.py index ef91e9e..ad0b50b 100644 --- a/onnxtr/models/classification/zoo.py +++ b/onnxtr/models/classification/zoo.py @@ -5,6 +5,8 @@ from typing import Any, List +from onnxtr.models.engine import EngineConfig + from .. import classification from ..preprocessor import PreProcessor from .predictor import OrientationPredictor @@ -14,12 +16,14 @@ ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"] -def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any) -> OrientationPredictor: +def _orientation_predictor( + arch: str, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any +) -> OrientationPredictor: if arch not in ORIENTATION_ARCHS: raise ValueError(f"unknown architecture '{arch}'") # Load directly classifier from backbone - _model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit) + _model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg) kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) @@ -32,7 +36,10 @@ def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any def crop_orientation_predictor( - arch: Any = "mobilenet_v3_small_crop_orientation", load_in_8_bit: bool = False, **kwargs: Any + arch: Any = "mobilenet_v3_small_crop_orientation", + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> OrientationPredictor: """Crop orientation classification architecture. @@ -46,17 +53,21 @@ def crop_orientation_predictor( ---- arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation') load_in_8_bit: load the 8-bit quantized version of the model + engine_cfg: configuration of inference engine **kwargs: keyword arguments to be passed to the OrientationPredictor Returns: ------- OrientationPredictor """ - return _orientation_predictor(arch, load_in_8_bit, **kwargs) + return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs) def page_orientation_predictor( - arch: Any = "mobilenet_v3_small_page_orientation", load_in_8_bit: bool = False, **kwargs: Any + arch: Any = "mobilenet_v3_small_page_orientation", + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> OrientationPredictor: """Page orientation classification architecture. @@ -70,10 +81,11 @@ def page_orientation_predictor( ---- arch: name of the architecture to use (e.g. 
'mobilenet_v3_small_page_orientation') load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments to be passed to the OrientationPredictor Returns: ------- OrientationPredictor """ - return _orientation_predictor(arch, load_in_8_bit, **kwargs) + return _orientation_predictor(arch, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/detection/models/differentiable_binarization.py b/onnxtr/models/detection/models/differentiable_binarization.py index 53be535..1747bcb 100644 --- a/onnxtr/models/detection/models/differentiable_binarization.py +++ b/onnxtr/models/detection/models/differentiable_binarization.py @@ -8,7 +8,7 @@ import numpy as np from scipy.special import expit -from ...engine import Engine +from ...engine import Engine, EngineConfig from ..postprocessor.base import GeneralDetectionPostProcessor __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"] @@ -33,8 +33,8 @@ "input_shape": (3, 1024, 1024), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_mobilenet_v3_large-1866973f.onnx", - "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_mobilenet_v3_large_static_8_bit-51659bb9.onnx", + "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large-4987e7bd.onnx", + "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.2.0/db_mobilenet_v3_large_static_8_bit-535a6f25.onnx", }, } @@ -45,6 +45,7 @@ class DBNet(Engine): Args: ---- model_path: path or url to onnx model file + engine_cfg: configuration for the inference engine bin_thresh: threshold for binarization of the output feature map box_thresh: minimal objectness score to consider a box assume_straight_pages: if True, fit straight bounding boxes only @@ -54,14 +55,15 @@ class DBNet(Engine): def __init__( self, - model_path, + model_path: str, + engine_cfg: EngineConfig = EngineConfig(), bin_thresh: float = 0.3, box_thresh: float = 0.1, assume_straight_pages: bool = True, cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - super().__init__(url=model_path, **kwargs) + super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs) self.cfg = cfg self.assume_straight_pages = assume_straight_pages self.postprocessor = GeneralDetectionPostProcessor( @@ -91,16 +93,20 @@ def _dbnet( arch: str, model_path: str, load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> DBNet: # Patch the url model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path # Build the model - return DBNet(model_path, cfg=default_cfgs[arch], **kwargs) + return DBNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs) def db_resnet34( - model_path: str = default_cfgs["db_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["db_resnet34"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> DBNet: """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" `_, using a ResNet-34 backbone. 
@@ -115,17 +121,21 @@ def db_resnet34( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: ------- text detection architecture """ - return _dbnet("db_resnet34", model_path, load_in_8_bit, **kwargs) + return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs) def db_resnet50( - model_path: str = default_cfgs["db_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["db_resnet50"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> DBNet: """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" `_, using a ResNet-50 backbone. @@ -140,17 +150,21 @@ def db_resnet50( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: ------- text detection architecture """ - return _dbnet("db_resnet50", model_path, load_in_8_bit, **kwargs) + return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs) def db_mobilenet_v3_large( - model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> DBNet: """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" `_, using a MobileNet V3 Large backbone. 
@@ -165,10 +179,11 @@ def db_mobilenet_v3_large( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: ------- text detection architecture """ - return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs) + return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/detection/models/fast.py b/onnxtr/models/detection/models/fast.py index 30d96c3..d8ee844 100644 --- a/onnxtr/models/detection/models/fast.py +++ b/onnxtr/models/detection/models/fast.py @@ -9,7 +9,7 @@ import numpy as np from scipy.special import expit -from ...engine import Engine +from ...engine import Engine, EngineConfig from ..postprocessor.base import GeneralDetectionPostProcessor __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"] @@ -43,6 +43,7 @@ class FAST(Engine): Args: ---- model_path: path or url to onnx model file + engine_cfg: configuration for the inference engine bin_thresh: threshold for binarization of the output feature map box_thresh: minimal objectness score to consider a box assume_straight_pages: if True, fit straight bounding boxes only @@ -53,13 +54,14 @@ class FAST(Engine): def __init__( self, model_path: str, + engine_cfg: EngineConfig = EngineConfig(), bin_thresh: float = 0.1, box_thresh: float = 0.1, assume_straight_pages: bool = True, cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - super().__init__(url=model_path, **kwargs) + super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs) self.cfg = cfg self.assume_straight_pages = assume_straight_pages @@ -90,15 +92,21 @@ def _fast( arch: str, model_path: str, load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> FAST: if load_in_8_bit: logging.warning("FAST models do not support 8-bit quantization yet. Loading full precision model...") # Build the model - return FAST(model_path, cfg=default_cfgs[arch], **kwargs) + return FAST(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs) -def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST: +def fast_tiny( + model_path: str = default_cfgs["fast_tiny"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, +) -> FAST: """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" `_, using a tiny TextNet backbone. 
@@ -112,16 +120,22 @@ def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], load_in_8_bit: ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: ------- text detection architecture """ - return _fast("fast_tiny", model_path, load_in_8_bit, **kwargs) + return _fast("fast_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs) -def fast_small(model_path: str = default_cfgs["fast_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST: +def fast_small( + model_path: str = default_cfgs["fast_small"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, +) -> FAST: """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" `_, using a small TextNet backbone. @@ -135,16 +149,22 @@ def fast_small(model_path: str = default_cfgs["fast_small"]["url"], load_in_8_bi ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: ------- text detection architecture """ - return _fast("fast_small", model_path, load_in_8_bit, **kwargs) + return _fast("fast_small", model_path, load_in_8_bit, engine_cfg, **kwargs) -def fast_base(model_path: str = default_cfgs["fast_base"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST: +def fast_base( + model_path: str = default_cfgs["fast_base"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, +) -> FAST: """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" `_, using a base TextNet backbone. 
@@ -158,10 +178,11 @@ def fast_base(model_path: str = default_cfgs["fast_base"]["url"], load_in_8_bit: ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: ------- text detection architecture """ - return _fast("fast_base", model_path, load_in_8_bit, **kwargs) + return _fast("fast_base", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/detection/models/linknet.py b/onnxtr/models/detection/models/linknet.py index 142be6b..852d1be 100644 --- a/onnxtr/models/detection/models/linknet.py +++ b/onnxtr/models/detection/models/linknet.py @@ -8,7 +8,7 @@ import numpy as np from scipy.special import expit -from ...engine import Engine +from ...engine import Engine, EngineConfig from ..postprocessor.base import GeneralDetectionPostProcessor __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"] @@ -45,6 +45,7 @@ class LinkNet(Engine): Args: ---- model_path: path or url to onnx model file + engine_cfg: configuration for the inference engine bin_thresh: threshold for binarization of the output feature map box_thresh: minimal objectness score to consider a box assume_straight_pages: if True, fit straight bounding boxes only @@ -55,13 +56,14 @@ class LinkNet(Engine): def __init__( self, model_path: str, + engine_cfg: EngineConfig = EngineConfig(), bin_thresh: float = 0.1, box_thresh: float = 0.1, assume_straight_pages: bool = True, cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - super().__init__(url=model_path, **kwargs) + super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs) self.cfg = cfg self.assume_straight_pages = assume_straight_pages @@ -92,16 +94,20 @@ def _linknet( arch: str, model_path: str, load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> LinkNet: # Patch the url model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path # Build the model - return LinkNet(model_path, cfg=default_cfgs[arch], **kwargs) + return LinkNet(model_path, cfg=default_cfgs[arch], engine_cfg=engine_cfg, **kwargs) def linknet_resnet18( - model_path: str = default_cfgs["linknet_resnet18"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["linknet_resnet18"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> LinkNet: """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" `_. 
@@ -116,17 +122,21 @@ def linknet_resnet18( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the LinkNet architecture Returns: ------- text detection architecture """ - return _linknet("linknet_resnet18", model_path, load_in_8_bit, **kwargs) + return _linknet("linknet_resnet18", model_path, load_in_8_bit, engine_cfg, **kwargs) def linknet_resnet34( - model_path: str = default_cfgs["linknet_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["linknet_resnet34"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> LinkNet: """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" `_. @@ -141,17 +151,21 @@ def linknet_resnet34( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the LinkNet architecture Returns: ------- text detection architecture """ - return _linknet("linknet_resnet34", model_path, load_in_8_bit, **kwargs) + return _linknet("linknet_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs) def linknet_resnet50( - model_path: str = default_cfgs["linknet_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["linknet_resnet50"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> LinkNet: """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" `_. @@ -166,10 +180,11 @@ def linknet_resnet50( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the LinkNet architecture Returns: ------- text detection architecture """ - return _linknet("linknet_resnet50", model_path, load_in_8_bit, **kwargs) + return _linknet("linknet_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/detection/zoo.py b/onnxtr/models/detection/zoo.py index e241a9d..cda0eed 100644 --- a/onnxtr/models/detection/zoo.py +++ b/onnxtr/models/detection/zoo.py @@ -6,6 +6,7 @@ from typing import Any from .. 
import detection +from ..engine import EngineConfig from ..preprocessor import PreProcessor from .predictor import DetectionPredictor @@ -25,13 +26,19 @@ def _predictor( - arch: Any, assume_straight_pages: bool = True, load_in_8_bit: bool = False, **kwargs: Any + arch: Any, + assume_straight_pages: bool = True, + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> DetectionPredictor: if isinstance(arch, str): if arch not in ARCHS: raise ValueError(f"unknown architecture '{arch}'") - _model = detection.__dict__[arch](assume_straight_pages=assume_straight_pages, load_in_8_bit=load_in_8_bit) + _model = detection.__dict__[arch]( + assume_straight_pages=assume_straight_pages, load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg + ) else: if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)): raise ValueError(f"unknown architecture: {type(arch)}") @@ -53,6 +60,7 @@ def detection_predictor( arch: Any = "fast_base", assume_straight_pages: bool = True, load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> DetectionPredictor: """Text detection architecture. @@ -68,10 +76,11 @@ def detection_predictor( arch: name of the architecture or model itself to use (e.g. 'db_resnet50') assume_straight_pages: If True, fit straight boxes to the page load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: optional keyword arguments passed to the architecture Returns: ------- Detection predictor """ - return _predictor(arch, assume_straight_pages, load_in_8_bit, **kwargs) + return _predictor(arch, assume_straight_pages, load_in_8_bit, engine_cfg=engine_cfg, **kwargs) diff --git a/onnxtr/models/engine.py b/onnxtr/models/engine.py index 7ca8183..d4adcfe 100644 --- a/onnxtr/models/engine.py +++ b/onnxtr/models/engine.py @@ -3,14 +3,79 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. 
-from typing import Any, List, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from onnxruntime import ExecutionMode, GraphOptimizationLevel, InferenceSession, SessionOptions +from onnxruntime import ( + ExecutionMode, + GraphOptimizationLevel, + InferenceSession, + SessionOptions, + get_available_providers, + get_device, +) from onnxtr.utils.data import download_from_url from onnxtr.utils.geometry import shape_translate +__all__ = ["EngineConfig"] + + +class EngineConfig: + """Implements a configuration class for the engine of a model + + Args: + ---- + providers: list of providers to use for inference ref.: https://onnxruntime.ai/docs/execution-providers/ + session_options: configuration for the inference session ref.: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions + """ + + def __init__( + self, + providers: Optional[Union[List[Tuple[str, Dict[str, Any]]], List[str]]] = None, + session_options: Optional[SessionOptions] = None, + ): + self._providers = providers or self._init_providers() + self._session_options = session_options or self._init_sess_opts() + + def _init_providers(self) -> List[Tuple[str, Dict[str, Any]]]: + providers: Any = [("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})] + available_providers = get_available_providers() + if "CUDAExecutionProvider" in available_providers and get_device() == "GPU": # pragma: no cover + providers.insert( + 0, + ( + "CUDAExecutionProvider", + { + "device_id": 0, + "arena_extend_strategy": "kNextPowerOfTwo", + "cudnn_conv_algo_search": "EXHAUSTIVE", + "do_copy_in_default_stream": True, + }, + ), + ) + return providers + + def _init_sess_opts(self) -> SessionOptions: + session_options = SessionOptions() + session_options.enable_cpu_mem_arena = True + session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL + session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + session_options.intra_op_num_threads = -1 + session_options.inter_op_num_threads = -1 + return session_options + + @property + def providers(self) -> Union[List[Tuple[str, Dict[str, Any]]], List[str]]: + return self._providers + + @property + def session_options(self) -> SessionOptions: + return self._session_options + + def __repr__(self) -> str: + return f"EngineConfig(providers={self.providers}" + class Engine: """Implements an abstract class for the engine of a model @@ -18,16 +83,15 @@ class Engine: Args: ---- url: the url to use to download a model if needed - providers: list of providers to use for inference + engine_cfg: the configuration of the engine **kwargs: additional arguments to be passed to `download_from_url` """ - def __init__( - self, url: str, providers: List[str] = ["CPUExecutionProvider", "CUDAExecutionProvider"], **kwargs: Any - ) -> None: + def __init__(self, url: str, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any) -> None: archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url - session_options = self._init_sess_opts() - self.runtime = InferenceSession(archive_path, providers=providers, sess_options=session_options) + self.session_options = engine_cfg.session_options + self.providers = engine_cfg.providers + self.runtime = InferenceSession(archive_path, providers=self.providers, sess_options=self.session_options) self.runtime_inputs = self.runtime.get_inputs()[0] self.tf_exported = int(self.runtime_inputs.shape[-1]) == 3 self.fixed_batch_size: Union[int, str] = self.runtime_inputs.shape[ @@ 
-35,15 +99,6 @@ def __init__( ] # mostly possible with tensorflow exported models self.output_name = [output.name for output in self.runtime.get_outputs()] - def _init_sess_opts(self) -> SessionOptions: - session_options = SessionOptions() - session_options.enable_cpu_mem_arena = True - session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL - session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL - session_options.intra_op_num_threads = -1 - session_options.inter_op_num_threads = -1 - return session_options - def run(self, inputs: np.ndarray) -> np.ndarray: if self.tf_exported: inputs = shape_translate(inputs, format="BHWC") # sanity check diff --git a/onnxtr/models/predictor/base.py b/onnxtr/models/predictor/base.py index aa0a690..d31deb7 100644 --- a/onnxtr/models/predictor/base.py +++ b/onnxtr/models/predictor/base.py @@ -8,6 +8,7 @@ import numpy as np from onnxtr.models.builder import DocumentBuilder +from onnxtr.models.engine import EngineConfig from onnxtr.utils.geometry import extract_crops, extract_rcrops, rotate_image from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds @@ -34,6 +35,7 @@ class _OCRPredictor: detect_orientation: if True, the estimated general page orientation will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + clf_engine_cfg: configuration of the orientation classification engine **kwargs: keyword args of `DocumentBuilder` """ @@ -48,15 +50,18 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, load_in_8_bit: bool = False, + clf_engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> None: self.assume_straight_pages = assume_straight_pages self.straighten_pages = straighten_pages self.crop_orientation_predictor = ( - None if assume_straight_pages else crop_orientation_predictor(load_in_8_bit=load_in_8_bit) + None + if assume_straight_pages + else crop_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg) ) self.page_orientation_predictor = ( - page_orientation_predictor(load_in_8_bit=load_in_8_bit) + page_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg) if detect_orientation or straighten_pages or not assume_straight_pages else None ) diff --git a/onnxtr/models/predictor/predictor.py b/onnxtr/models/predictor/predictor.py index ab54311..20c8b79 100644 --- a/onnxtr/models/predictor/predictor.py +++ b/onnxtr/models/predictor/predictor.py @@ -10,6 +10,7 @@ from onnxtr.io.elements import Document from onnxtr.models._utils import get_language from onnxtr.models.detection.predictor import DetectionPredictor +from onnxtr.models.engine import EngineConfig from onnxtr.models.recognition.predictor import RecognitionPredictor from onnxtr.utils.geometry import detach_scores from onnxtr.utils.repr import NestedObject @@ -35,6 +36,7 @@ class OCRPredictor(NestedObject, _OCRPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
+ clf_engine_cfg: configuration of the orientation classification engine **kwargs: keyword args of `DocumentBuilder` """ @@ -50,6 +52,7 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, + clf_engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> None: self.det_predictor = det_predictor @@ -61,6 +64,7 @@ def __init__( preserve_aspect_ratio, symmetric_pad, detect_orientation, + clf_engine_cfg=clf_engine_cfg, **kwargs, ) self.detect_orientation = detect_orientation diff --git a/onnxtr/models/recognition/models/crnn.py b/onnxtr/models/recognition/models/crnn.py index a969de5..3ce0181 100644 --- a/onnxtr/models/recognition/models/crnn.py +++ b/onnxtr/models/recognition/models/crnn.py @@ -12,7 +12,7 @@ from onnxtr.utils import VOCABS -from ...engine import Engine +from ...engine import Engine, EngineConfig from ..core import RecognitionPostProcessor __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] @@ -113,6 +113,7 @@ class CRNN(Engine): ---- model_path: path or url to onnx model file vocab: vocabulary used for encoding + engine_cfg: configuration for the inference engine cfg: configuration dictionary **kwargs: additional arguments to be passed to `Engine` """ @@ -123,10 +124,11 @@ def __init__( self, model_path: str, vocab: str, + engine_cfg: EngineConfig = EngineConfig(), cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - super().__init__(url=model_path, **kwargs) + super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs) self.vocab = vocab self.cfg = cfg self.postprocessor = CRNNPostProcessor(self.vocab) @@ -152,6 +154,7 @@ def _crnn( arch: str, model_path: str, load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> CRNN: kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"]) @@ -163,11 +166,14 @@ def _crnn( model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path # Build the model - return CRNN(model_path, cfg=_cfg, **kwargs) + return CRNN(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs) def crnn_vgg16_bn( - model_path: str = default_cfgs["crnn_vgg16_bn"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["crnn_vgg16_bn"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> CRNN: """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition" `_. 
@@ -182,17 +188,21 @@ def crnn_vgg16_bn( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the CRNN architecture Returns: ------- text recognition architecture """ - return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, **kwargs) + return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, engine_cfg, **kwargs) def crnn_mobilenet_v3_small( - model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> CRNN: """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition" `_. @@ -207,17 +217,21 @@ def crnn_mobilenet_v3_small( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the CRNN architecture Returns: ------- text recognition architecture """ - return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, **kwargs) + return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, engine_cfg, **kwargs) def crnn_mobilenet_v3_large( - model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> CRNN: """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition" `_. 
@@ -232,10 +246,11 @@ def crnn_mobilenet_v3_large( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the CRNN architecture Returns: ------- text recognition architecture """ - return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs) + return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/recognition/models/master.py b/onnxtr/models/recognition/models/master.py index 164f42b..10cadbc 100644 --- a/onnxtr/models/recognition/models/master.py +++ b/onnxtr/models/recognition/models/master.py @@ -11,7 +11,7 @@ from onnxtr.utils import VOCABS -from ...engine import Engine +from ...engine import Engine, EngineConfig from ..core import RecognitionPostProcessor __all__ = ["MASTER", "master"] @@ -36,6 +36,7 @@ class MASTER(Engine): ---- model_path: path or url to onnx model file vocab: vocabulary, (without EOS, SOS, PAD) + engine_cfg: configuration for the inference engine cfg: dictionary containing information about the model **kwargs: additional arguments to be passed to `Engine` """ @@ -44,10 +45,11 @@ def __init__( self, model_path: str, vocab: str, + engine_cfg: EngineConfig = EngineConfig(), cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - super().__init__(url=model_path, **kwargs) + super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs) self.vocab = vocab self.cfg = cfg @@ -114,6 +116,7 @@ def _master( arch: str, model_path: str, load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> MASTER: # Patch the config @@ -125,10 +128,15 @@ def _master( # Patch the url model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path - return MASTER(model_path, cfg=_cfg, **kwargs) + return MASTER(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs) -def master(model_path: str = default_cfgs["master"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> MASTER: +def master( + model_path: str = default_cfgs["master"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, +) -> MASTER: """MASTER as described in paper: `_. 
>>> import numpy as np @@ -141,10 +149,11 @@ def master(model_path: str = default_cfgs["master"]["url"], load_in_8_bit: bool ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keywoard arguments passed to the MASTER architecture Returns: ------- text recognition architecture """ - return _master("master", model_path, load_in_8_bit, **kwargs) + return _master("master", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/recognition/models/parseq.py b/onnxtr/models/recognition/models/parseq.py index f332f45..8ccdec6 100644 --- a/onnxtr/models/recognition/models/parseq.py +++ b/onnxtr/models/recognition/models/parseq.py @@ -11,7 +11,7 @@ from onnxtr.utils import VOCABS -from ...engine import Engine +from ...engine import Engine, EngineConfig from ..core import RecognitionPostProcessor __all__ = ["PARSeq", "parseq"] @@ -35,6 +35,7 @@ class PARSeq(Engine): ---- model_path: path to onnx model file vocab: vocabulary used for encoding + engine_cfg: configuration for the inference engine cfg: dictionary containing information about the model **kwargs: additional arguments to be passed to `Engine` """ @@ -43,10 +44,11 @@ def __init__( self, model_path: str, vocab: str, + engine_cfg: EngineConfig = EngineConfig(), cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - super().__init__(url=model_path, **kwargs) + super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs) self.vocab = vocab self.cfg = cfg self.postprocessor = PARSeqPostProcessor(vocab=self.vocab) @@ -102,6 +104,7 @@ def _parseq( arch: str, model_path: str, load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> PARSeq: # Patch the config @@ -114,10 +117,15 @@ def _parseq( model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path # Build the model - return PARSeq(model_path, cfg=_cfg, **kwargs) + return PARSeq(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs) -def parseq(model_path: str = default_cfgs["parseq"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> PARSeq: +def parseq( + model_path: str = default_cfgs["parseq"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, +) -> PARSeq: """PARSeq architecture from `"Scene Text Recognition with Permuted Autoregressive Sequence Models" `_. 
@@ -131,10 +139,11 @@ def parseq(model_path: str = default_cfgs["parseq"]["url"], load_in_8_bit: bool ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the PARSeq architecture Returns: ------- text recognition architecture """ - return _parseq("parseq", model_path, load_in_8_bit, **kwargs) + return _parseq("parseq", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/recognition/models/sar.py b/onnxtr/models/recognition/models/sar.py index 975d61d..7ea2415 100644 --- a/onnxtr/models/recognition/models/sar.py +++ b/onnxtr/models/recognition/models/sar.py @@ -11,7 +11,7 @@ from onnxtr.utils import VOCABS -from ...engine import Engine +from ...engine import Engine, EngineConfig from ..core import RecognitionPostProcessor __all__ = ["SAR", "sar_resnet31"] @@ -35,6 +35,7 @@ class SAR(Engine): ---- model_path: path to onnx model file vocab: vocabulary used for encoding + engine_cfg: configuration for the inference engine cfg: dictionary containing information about the model **kwargs: additional arguments to be passed to `Engine` """ @@ -43,10 +44,11 @@ def __init__( self, model_path: str, vocab: str, + engine_cfg: EngineConfig = EngineConfig(), cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - super().__init__(url=model_path, **kwargs) + super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs) self.vocab = vocab self.cfg = cfg self.postprocessor = SARPostProcessor(self.vocab) @@ -101,6 +103,7 @@ def _sar( arch: str, model_path: str, load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> SAR: # Patch the config @@ -113,11 +116,14 @@ def _sar( model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path # Build the model - return SAR(model_path, cfg=_cfg, **kwargs) + return SAR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs) def sar_resnet31( - model_path: str = default_cfgs["sar_resnet31"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["sar_resnet31"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> SAR: """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong Baseline for Irregular Text Recognition" `_. 
@@ -132,10 +138,11 @@ def sar_resnet31( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the SAR architecture Returns: ------- text recognition architecture """ - return _sar("sar_resnet31", model_path, load_in_8_bit, **kwargs) + return _sar("sar_resnet31", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/recognition/models/vitstr.py b/onnxtr/models/recognition/models/vitstr.py index 508cbe5..ab37d3e 100644 --- a/onnxtr/models/recognition/models/vitstr.py +++ b/onnxtr/models/recognition/models/vitstr.py @@ -11,7 +11,7 @@ from onnxtr.utils import VOCABS -from ...engine import Engine +from ...engine import Engine, EngineConfig from ..core import RecognitionPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -43,6 +43,7 @@ class ViTSTR(Engine): ---- model_path: path to onnx model file vocab: vocabulary used for encoding + engine_cfg: configuration for the inference engine cfg: dictionary containing information about the model **kwargs: additional arguments to be passed to `Engine` """ @@ -51,10 +52,11 @@ def __init__( self, model_path: str, vocab: str, + engine_cfg: EngineConfig = EngineConfig(), cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: - super().__init__(url=model_path, **kwargs) + super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs) self.vocab = vocab self.cfg = cfg @@ -112,6 +114,7 @@ def _vitstr( arch: str, model_path: str, load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> ViTSTR: # Patch the config @@ -124,11 +127,14 @@ def _vitstr( model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path # Build the model - return ViTSTR(model_path, cfg=_cfg, **kwargs) + return ViTSTR(model_path, cfg=_cfg, engine_cfg=engine_cfg, **kwargs) def vitstr_small( - model_path: str = default_cfgs["vitstr_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["vitstr_small"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> ViTSTR: """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition" `_. @@ -143,17 +149,21 @@ def vitstr_small( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the ViTSTR architecture Returns: ------- text recognition architecture """ - return _vitstr("vitstr_small", model_path, load_in_8_bit, **kwargs) + return _vitstr("vitstr_small", model_path, load_in_8_bit, engine_cfg, **kwargs) def vitstr_base( - model_path: str = default_cfgs["vitstr_base"]["url"], load_in_8_bit: bool = False, **kwargs: Any + model_path: str = default_cfgs["vitstr_base"]["url"], + load_in_8_bit: bool = False, + engine_cfg: EngineConfig = EngineConfig(), + **kwargs: Any, ) -> ViTSTR: """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition" `_. 
@@ -168,10 +178,11 @@ def vitstr_base( ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the ViTSTR architecture Returns: ------- text recognition architecture """ - return _vitstr("vitstr_base", model_path, load_in_8_bit, **kwargs) + return _vitstr("vitstr_base", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/recognition/zoo.py b/onnxtr/models/recognition/zoo.py index 8f3fedf..d237290 100644 --- a/onnxtr/models/recognition/zoo.py +++ b/onnxtr/models/recognition/zoo.py @@ -5,9 +5,9 @@ from typing import Any, List -from onnxtr.models.preprocessor import PreProcessor - from .. import recognition +from ..engine import EngineConfig +from ..preprocessor import PreProcessor from .predictor import RecognitionPredictor __all__ = ["recognition_predictor"] @@ -25,12 +25,14 @@ ] -def _predictor(arch: Any, load_in_8_bit: bool = False, **kwargs: Any) -> RecognitionPredictor: +def _predictor( + arch: Any, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any +) -> RecognitionPredictor: if isinstance(arch, str): if arch not in ARCHS: raise ValueError(f"unknown architecture '{arch}'") - _model = recognition.__dict__[arch](load_in_8_bit=load_in_8_bit) + _model = recognition.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg) else: if not isinstance( arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq) @@ -48,7 +50,7 @@ def _predictor(arch: Any, load_in_8_bit: bool = False, **kwargs: Any) -> Recogni def recognition_predictor( - arch: Any = "crnn_vgg16_bn", load_in_8_bit: bool = False, **kwargs: Any + arch: Any = "crnn_vgg16_bn", load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any ) -> RecognitionPredictor: """Text recognition architecture. @@ -63,10 +65,11 @@ def recognition_predictor( ---- arch: name of the architecture or model itself to use (e.g. 
'crnn_vgg16_bn') load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + engine_cfg: configuration of inference engine **kwargs: optional parameters to be passed to the architecture Returns: ------- Recognition predictor """ - return _predictor(arch, load_in_8_bit, **kwargs) + return _predictor(arch, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/zoo.py b/onnxtr/models/zoo.py index ba650c9..f681306 100644 --- a/onnxtr/models/zoo.py +++ b/onnxtr/models/zoo.py @@ -6,6 +6,7 @@ from typing import Any from .detection.zoo import detection_predictor +from .engine import EngineConfig from .predictor import OCRPredictor from .recognition.zoo import recognition_predictor @@ -24,6 +25,9 @@ def _predictor( straighten_pages: bool = False, detect_language: bool = False, load_in_8_bit: bool = False, + det_engine_cfg: EngineConfig = EngineConfig(), + reco_engine_cfg: EngineConfig = EngineConfig(), + clf_engine_cfg: EngineConfig = EngineConfig(), **kwargs, ) -> OCRPredictor: # Detection @@ -34,6 +38,7 @@ def _predictor( preserve_aspect_ratio=preserve_aspect_ratio, symmetric_pad=symmetric_pad, load_in_8_bit=load_in_8_bit, + engine_cfg=det_engine_cfg, ) # Recognition @@ -41,6 +46,7 @@ def _predictor( reco_arch, batch_size=reco_bs, load_in_8_bit=load_in_8_bit, + engine_cfg=reco_engine_cfg, ) return OCRPredictor( @@ -52,6 +58,7 @@ def _predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, + clf_engine_cfg=clf_engine_cfg, **kwargs, ) @@ -67,6 +74,9 @@ def ocr_predictor( straighten_pages: bool = False, detect_language: bool = False, load_in_8_bit: bool = False, + det_engine_cfg: EngineConfig = EngineConfig(), + reco_engine_cfg: EngineConfig = EngineConfig(), + clf_engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any, ) -> OCRPredictor: """End-to-end OCR architecture using one model for localization, and another for text recognition. @@ -99,6 +109,9 @@ def ocr_predictor( detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False + det_engine_cfg: configuration of the detection engine + reco_engine_cfg: configuration of the recognition engine + clf_engine_cfg: configuration of the orientation classification engine kwargs: keyword args of `OCRPredictor` Returns: @@ -116,5 +129,8 @@ def ocr_predictor( straighten_pages=straighten_pages, detect_language=detect_language, load_in_8_bit=load_in_8_bit, + det_engine_cfg=det_engine_cfg, + reco_engine_cfg=reco_engine_cfg, + clf_engine_cfg=clf_engine_cfg, **kwargs, ) diff --git a/scripts/quantize.py b/scripts/quantize.py index e40596e..93b6a8b 100644 --- a/scripts/quantize.py +++ b/scripts/quantize.py @@ -140,15 +140,27 @@ def main(): # Turn off model optimization during quantization if "parseq" not in input_model_path: # Skip static quantization for Parseq print("Calibrating and quantizing model static...") - quantize_static( - input_model_path, - static_out_name, - dr, - quant_format=args.quant_format, - weight_type=QuantType.QUInt8, - activation_type=QuantType.QUInt8, - reduce_range=True, - ) + try: + quantize_static( + input_model_path, + static_out_name, + dr, + quant_format=args.quant_format, + weight_type=QuantType.QInt8, + activation_type=QuantType.QUInt8, + reduce_range=True, + ) + except Exception: + print("Error during static quantization --> Change weight_type also to QUInt8") + quantize_static( + input_model_path, + static_out_name, + dr, + quant_format=args.quant_format, + weight_type=QuantType.QUInt8, + activation_type=QuantType.QUInt8, + reduce_range=True, + ) print("benchmarking static int8 model...") benchmark(calibration_dataset_path, static_out_name, task_shape) diff --git a/setup.py b/setup.py index 9ce9d41..e2a55d2 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ from setuptools import setup PKG_NAME = "onnxtr" -VERSION = os.getenv("BUILD_VERSION", "0.2.1a0") +VERSION = os.getenv("BUILD_VERSION", "0.3.0a0") if __name__ == "__main__": diff --git a/tests/common/test_engine_cfg.py b/tests/common/test_engine_cfg.py new file mode 100644 index 0000000..07f5613 --- /dev/null +++ b/tests/common/test_engine_cfg.py @@ -0,0 +1,66 @@ +import numpy as np +import pytest +from onnxruntime import SessionOptions + +from onnxtr import models +from onnxtr.io import Document +from onnxtr.models import EngineConfig, detection, recognition +from onnxtr.models.predictor import OCRPredictor + + +def _test_predictor(predictor): + # Output checks + assert isinstance(predictor, OCRPredictor) + + doc = [np.zeros((1024, 1024, 3), dtype=np.uint8)] + out = predictor(doc) + # Document + assert isinstance(out, Document) + + # The input doc has 1 page + assert len(out.pages) == 1 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + +@pytest.mark.parametrize( + "det_arch, reco_arch", + [[det_arch, reco_arch] for det_arch, reco_arch in zip(detection.zoo.ARCHS, recognition.zoo.ARCHS)], +) +def test_engine_cfg(det_arch, reco_arch): + session_options = SessionOptions() + session_options.enable_cpu_mem_arena = False + engine_cfg = EngineConfig( + providers=["CPUExecutionProvider"], + session_options=session_options, + ) + + assert engine_cfg.__repr__() == "EngineConfig(providers=['CPUExecutionProvider']" + + # Model + predictor = models.ocr_predictor( + det_arch, reco_arch, det_engine_cfg=engine_cfg, reco_engine_cfg=engine_cfg, clf_engine_cfg=engine_cfg + ) + assert predictor.det_predictor.model.providers == 
["CPUExecutionProvider"] + assert not predictor.det_predictor.model.session_options.enable_cpu_mem_arena + assert predictor.reco_predictor.model.providers == ["CPUExecutionProvider"] + assert not predictor.reco_predictor.model.session_options.enable_cpu_mem_arena + _test_predictor(predictor) + + # passing model instance directly + det_model = detection.__dict__[det_arch](engine_cfg=engine_cfg) + assert det_model.providers == ["CPUExecutionProvider"] + assert not det_model.session_options.enable_cpu_mem_arena + + reco_model = recognition.__dict__[reco_arch](engine_cfg=engine_cfg) + assert reco_model.providers == ["CPUExecutionProvider"] + assert not reco_model.session_options.enable_cpu_mem_arena + + predictor = models.ocr_predictor(det_model, reco_model) + assert predictor.det_predictor.model.providers == ["CPUExecutionProvider"] + assert not predictor.det_predictor.model.session_options.enable_cpu_mem_arena + assert predictor.reco_predictor.model.providers == ["CPUExecutionProvider"] + assert not predictor.reco_predictor.model.session_options.enable_cpu_mem_arena + _test_predictor(predictor) diff --git a/tests/common/test_models_zoo.py b/tests/common/test_models_zoo.py index 8937242..3edc0c9 100644 --- a/tests/common/test_models_zoo.py +++ b/tests/common/test_models_zoo.py @@ -52,6 +52,8 @@ def test_ocrpredictor(mock_pdf, assume_straight_pages, straighten_pages): straighten_pages=straighten_pages, detect_orientation=True, detect_language=True, + resolve_lines=True, + resolve_blocks=True, ) if assume_straight_pages: @@ -72,8 +74,7 @@ def test_ocrpredictor(mock_pdf, assume_straight_pages, straighten_pages): input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) _ = predictor([input_page]) - orientation = 0 - assert out.pages[0].orientation["value"] == orientation + assert out.pages[0].orientation["value"] in range(-2, 3) assert isinstance(out.pages[0].language["value"], str) assert isinstance(out.render(), str) assert isinstance(out.pages[0].render(), str) @@ -102,6 +103,8 @@ def test_trained_ocr_predictor(mock_payslip): assume_straight_pages=True, straighten_pages=True, preserve_aspect_ratio=False, + resolve_lines=True, + resolve_blocks=True, ) # test hooks predictor.add_hook(_DummyCallback()) @@ -131,6 +134,8 @@ def test_trained_ocr_predictor(mock_payslip): straighten_pages=True, preserve_aspect_ratio=True, symmetric_pad=True, + resolve_lines=True, + resolve_blocks=True, ) out = predictor(doc) From 82d7139824c89d656de91ad0129484351f954825 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 27 Jun 2024 11:04:21 +0200 Subject: [PATCH 2/4] conda file --- .conda/meta.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 38aba5e..db71586 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -1,7 +1,7 @@ {% set pyproject = load_file_data('../pyproject.toml', from_recipe_dir=True) %} {% set project = pyproject.get('project') %} {% set urls = pyproject.get('project', {}).get('urls') %} -{% set version = environ.get('BUILD_VERSION', '0.2.1a0') %} +{% set version = environ.get('BUILD_VERSION', '0.3.0a0') %} package: name: onnxtr From fbb1852c7eb61d35f0e760878bd2310e81276ae2 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 27 Jun 2024 11:15:51 +0200 Subject: [PATCH 3/4] update tests --- tests/common/test_engine_cfg.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/common/test_engine_cfg.py b/tests/common/test_engine_cfg.py index 07f5613..f229444 100644 --- a/tests/common/test_engine_cfg.py +++ 
b/tests/common/test_engine_cfg.py @@ -64,3 +64,11 @@ def test_engine_cfg(det_arch, reco_arch): assert predictor.reco_predictor.model.providers == ["CPUExecutionProvider"] assert not predictor.reco_predictor.model.session_options.enable_cpu_mem_arena _test_predictor(predictor) + + det_predictor = models.detection_predictor(det_arch, engine_cfg=engine_cfg) + assert det_predictor.model.providers == ["CPUExecutionProvider"] + assert not det_predictor.model.session_options.enable_cpu_mem_arena + + reco_predictor = models.recognition_predictor(reco_arch, engine_cfg=engine_cfg) + assert reco_predictor.model.providers == ["CPUExecutionProvider"] + assert not reco_predictor.model.session_options.enable_cpu_mem_arena From 897888e28f5e43e757fae49bb09f00769f5b99a8 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 27 Jun 2024 11:53:48 +0200 Subject: [PATCH 4/4] update --- README.md | 31 ++++++++++++++++++++++++++----- onnxtr/models/engine.py | 1 + 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 23f5fb0..122eac2 100644 --- a/README.md +++ b/README.md @@ -110,13 +110,34 @@ result.show()
 Advanced engine configuration options

-Install from source:
+You can also define advanced engine configurations for the models / predictors:

-```bash
-git clone https://github.com/Lightning-AI/litgpt
-cd litgpt
-pip install -e '.[all]'
+```python
+from onnxruntime import SessionOptions
+
+from onnxtr.models import ocr_predictor, EngineConfig
+
+general_options = SessionOptions()  # For configuration options see: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
+general_options.enable_cpu_mem_arena = False
+
+# NOTE: The following forces execution on the GPU only; if no GPU is available, it will raise an error.
+# Providers can be a list of strings, e.g. ["CUDAExecutionProvider", "CPUExecutionProvider"], or a list of tuples with the provider and its options, e.g.
+# [("CUDAExecutionProvider", {"device_id": 0}), ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
+providers = [("CUDAExecutionProvider", {"device_id": 0})]  # For available providers see: https://onnxruntime.ai/docs/execution-providers/
+
+engine_config = EngineConfig(
+    session_options=general_options,
+    providers=providers
+)
+# We use the default predictor with the custom engine configuration
+# NOTE: You can define different engine configurations for detection, recognition and classification, depending on your needs
+predictor = ocr_predictor(
+    det_engine_cfg=engine_config,
+    reco_engine_cfg=engine_config,
+    clf_engine_cfg=engine_config
+)
 ```
+
![Visualization sample](https://github.com/felixdittrich92/OnnxTR/raw/main/docs/images/doctr_example_script.gif) diff --git a/onnxtr/models/engine.py b/onnxtr/models/engine.py index d4adcfe..b8035aa 100644 --- a/onnxtr/models/engine.py +++ b/onnxtr/models/engine.py @@ -88,6 +88,7 @@ class Engine: """ def __init__(self, url: str, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any) -> None: + engine_cfg = engine_cfg or EngineConfig() archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url self.session_options = engine_cfg.session_options self.providers = engine_cfg.providers
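One note on the final hunk: the `engine_cfg = engine_cfg or EngineConfig()` guard is what lets callers forward `engine_cfg=None` explicitly (as the predictor factories presumably do when no configuration is supplied) without tripping over the attribute accesses that follow. A minimal sketch of the behavior this buys, using stripped-down stand-ins for the real `EngineConfig` and `Engine` classes; the attributes shown are assumptions for illustration, not the library's full interface:

```python
from typing import Any, List, Optional


class EngineConfig:
    """Stand-in for onnxtr.models.EngineConfig -- illustrative only."""

    def __init__(self, providers: Optional[List[Any]] = None, session_options: Any = None) -> None:
        # Assumed default: fall back to CPU execution when nothing is configured
        self.providers = providers or ["CPUExecutionProvider"]
        self.session_options = session_options


class Engine:
    """Stand-in for onnxtr.models.engine.Engine -- illustrative only."""

    def __init__(self, url: str, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any) -> None:
        # The guard from PATCH 4/4: an explicit `engine_cfg=None` falls back to the
        # default configuration instead of raising AttributeError just below.
        engine_cfg = engine_cfg or EngineConfig()
        self.session_options = engine_cfg.session_options
        self.providers = engine_cfg.providers


# Both calls end up with a usable configuration:
assert Engine("model.onnx", engine_cfg=None).providers == ["CPUExecutionProvider"]
assert Engine("model.onnx", EngineConfig(providers=["CUDAExecutionProvider"])).providers == ["CUDAExecutionProvider"]
```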
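And since the README example added above is GPU-oriented, a CPU-tuned variant is the other common configuration. The sketch below assumes only the `EngineConfig` interface introduced in this patch together with standard `onnxruntime.SessionOptions` attributes; the thread counts and optimization level are illustrative choices, not project defaults:

```python
import onnxruntime as ort

from onnxtr.models import EngineConfig, ocr_predictor

opts = ort.SessionOptions()
opts.intra_op_num_threads = 4  # parallelism within a single operator
opts.inter_op_num_threads = 1  # parallelism across independent operators
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
opts.enable_cpu_mem_arena = True  # keep the memory arena on when RAM is not constrained

cpu_engine = EngineConfig(
    providers=["CPUExecutionProvider"],
    session_options=opts,
)

# As with the GPU example, each stage can receive its own configuration
predictor = ocr_predictor(
    det_engine_cfg=cpu_engine,
    reco_engine_cfg=cpu_engine,
    clf_engine_cfg=cpu_engine,
)
```

Keeping `inter_op_num_threads` at 1 while tuning `intra_op_num_threads` is a reasonable starting point for a sequential OCR pipeline, but the right values depend on the host machine.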