Skip to content

Commit

Permalink
[prototype] Extend detection result customization (mindee#1449)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Feb 8, 2024
1 parent 04aa84c commit 3811400
Show file tree
Hide file tree
Showing 18 changed files with 128 additions and 8 deletions.
5 changes: 4 additions & 1 deletion demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def main(det_archs, reco_archs):
# Binarization threshold
bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
st.sidebar.write("\n")
# Box threshold
box_thresh = st.sidebar.slider("Box threshold", min_value=0.1, max_value=0.9, value=0.1, step=0.1)
st.sidebar.write("\n")

if st.sidebar.button("Analyze page"):
if uploaded_file is None:
Expand All @@ -86,7 +89,7 @@ def main(det_archs, reco_archs):
else:
with st.spinner("Loading model..."):
predictor = load_predictor(
det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, box_thresh, forward_device
)

with st.spinner("Analyzing..."):
Expand Down
3 changes: 3 additions & 0 deletions demo/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def load_predictor(
assume_straight_pages: bool,
straighten_pages: bool,
bin_thresh: float,
box_thresh: float,
device: torch.device,
) -> OCRPredictor:
"""Load a predictor from doctr.models
Expand All @@ -46,6 +47,7 @@ def load_predictor(
assume_straight_pages: whether to assume straight pages or not
straighten_pages: whether to straighten rotated pages or not
bin_thresh: binarization threshold for the segmentation map
box_thresh: minimal objectness score to consider a box
device: torch.device, the device to load the predictor on
Returns:
Expand All @@ -62,6 +64,7 @@ def load_predictor(
detect_orientation=not assume_straight_pages,
).to(device)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
return predictor


Expand Down
3 changes: 3 additions & 0 deletions demo/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def load_predictor(
assume_straight_pages: bool,
straighten_pages: bool,
bin_thresh: float,
box_thresh: float,
device: tf.device,
) -> OCRPredictor:
"""Load a predictor from doctr.models
Expand All @@ -45,6 +46,7 @@ def load_predictor(
assume_straight_pages: whether to assume straight pages or not
straighten_pages: whether to straighten rotated pages or not
bin_thresh: binarization threshold for the segmentation map
box_thresh: threshold for the detection boxes
device: tf.device, the device to load the predictor on
Returns:
Expand All @@ -62,6 +64,7 @@ def load_predictor(
detect_orientation=not assume_straight_pages,
)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
return predictor


Expand Down
47 changes: 47 additions & 0 deletions docs/source/using_doctr/using_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,50 @@ For reference, here is a sample XML byte string output:
</div>
</body>
</html>
Advanced options
^^^^^^^^^^^^^^^^
We provide a few advanced options to customize the behavior of the predictor to your needs:

* Modify the binarization threshold for the detection model.
* Modify the box threshold for the detection model.

This is useful to detect (possibly fewer) text regions more accurately with a higher threshold, or to detect more text regions with a lower threshold.


.. code:: python3
import numpy as np
from doctr.models import ocr_predictor
predictor = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True)
# Modify the binarization threshold and the box threshold
predictor.det_predictor.model.postprocessor.bin_thresh = 0.5
predictor.det_predictor.model.postprocessor.box_thresh = 0.2
input_page = (255 * np.random.rand(800, 600, 3)).astype(np.uint8)
out = predictor([input_page])
* Add a hook to the `ocr_predictor` to manipulate the location predictions before the crops are passed to the recognition model.

.. code:: python3
from doctr.models import ocr_predictor
class CustomHook:
def __call__(self, loc_preds):
# Manipulate the location predictions here
# 1. The output structure needs to be the same as the input location predictions
# 2. Be aware that the coordinates are relative and need to be between 0 and 1
return loc_preds
my_hook = CustomHook()
predictor = ocr_predictor(pretrained=True)
# Add a hook in the middle of the pipeline
predictor.add_hook(my_hook)
# You can also add multiple hooks which will be executed sequentially
for hook in [my_hook, my_hook, my_hook]:
predictor.add_hook(hook)
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class DBNet(_DBNet, nn.Module):
feature extractor: the backbone serving as feature extractor
head_chans: the number of channels in the head
deform_conv: whether to use deformable convolution
bin_thresh: threshold for binarization
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
exportable: onnx exportable returns only logits
cfg: the configuration dict of the model
Expand All @@ -112,6 +114,7 @@ def __init__(
head_chans: int = 256,
deform_conv: bool = False,
bin_thresh: float = 0.3,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -160,7 +163,9 @@ def __init__(
nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2),
)

self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
self.postprocessor = DBPostProcessor(
assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
)

for n, m in self.named_modules():
# Don't override the initialization of the backbone
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class DBNet(_DBNet, keras.Model, NestedObject):
----
feature extractor: the backbone serving as feature extractor
fpn_channels: number of channels each extracted feature maps is mapped to
bin_thresh: threshold for binarization
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
exportable: onnx exportable returns only logits
cfg: the configuration dict of the model
Expand All @@ -125,6 +127,7 @@ def __init__(
feature_extractor: IntermediateLayerGetter,
fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea
bin_thresh: float = 0.3,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -159,7 +162,9 @@ def __init__(
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
])

self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
self.postprocessor = DBPostProcessor(
assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
)

def compute_loss(
self,
Expand Down
5 changes: 4 additions & 1 deletion doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class LinkNet(nn.Module, _LinkNet):
Args:
----
feature extractor: the backbone serving as feature extractor
bin_thresh: threshold for binarization of the output feature map
box_thresh: minimal objectness score to consider a box
head_chans: number of channels in the head layers
assume_straight_pages: if True, fit straight bounding boxes only
exportable: onnx exportable returns only logits
Expand All @@ -102,6 +104,7 @@ def __init__(
self,
feat_extractor: IntermediateLayerGetter,
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
head_chans: int = 32,
assume_straight_pages: bool = True,
exportable: bool = False,
Expand Down Expand Up @@ -142,7 +145,7 @@ def __init__(
)

self.postprocessor = LinkNetPostProcessor(
assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh
assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
)

for n, m in self.named_modules():
Expand Down
7 changes: 6 additions & 1 deletion doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class LinkNet(_LinkNet, keras.Model):
----
feature extractor: the backbone serving as feature extractor
fpn_channels: number of channels each extracted feature maps is mapped to
bin_thresh: threshold for binarization of the output feature map
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
exportable: onnx exportable returns only logits
cfg: the configuration dict of the model
Expand All @@ -111,6 +113,7 @@ def __init__(
feat_extractor: IntermediateLayerGetter,
fpn_channels: int = 64,
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -152,7 +155,9 @@ def __init__(
),
])

self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
self.postprocessor = LinkNetPostProcessor(
assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
)

def compute_loss(
self,
Expand Down
4 changes: 4 additions & 0 deletions doctr/models/kie_predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def forward(
# Rectify crops if aspect ratio
dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}

# Apply hooks to loc_preds if any
for hook in self.hooks:
dict_loc_preds = hook(dict_loc_preds)

# Crop images
crops = {}
for class_name in dict_loc_preds.keys():
Expand Down
4 changes: 4 additions & 0 deletions doctr/models/kie_predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def __call__(
# Rectify crops if aspect ratio
dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}

# Apply hooks to loc_preds if any
for hook in self.hooks:
dict_loc_preds = hook(dict_loc_preds)

# Crop images
crops = {}
for class_name in dict_loc_preds.keys():
Expand Down
12 changes: 11 additions & 1 deletion doctr/models/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -48,6 +48,7 @@ def __init__(
self.doc_builder = DocumentBuilder(**kwargs)
self.preserve_aspect_ratio = preserve_aspect_ratio
self.symmetric_pad = symmetric_pad
self.hooks: List[Callable] = []

@staticmethod
def _generate_crops(
Expand Down Expand Up @@ -149,3 +150,12 @@ def _process_predictions(
_idx += page_boxes.shape[0]

return loc_preds, text_preds

def add_hook(self, hook: Callable) -> None:
    """Add a hook to the predictor

    The hook is appended to `self.hooks`, so hooks are kept in registration order.

    Args:
    ----
        hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`

    Raises:
    ------
        TypeError: if `hook` is not callable
    """
    if not callable(hook):
        # Fail at registration time instead of in the middle of a prediction run
        raise TypeError(f"hook must be callable, got {type(hook).__name__}")
    self.hooks.append(hook)
4 changes: 4 additions & 0 deletions doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def forward(
# Rectify crops if aspect ratio
loc_preds = self._remove_padding(pages, loc_preds)

# Apply hooks to loc_preds if any
for hook in self.hooks:
loc_preds = hook(loc_preds)

# Crop images
crops, loc_preds = self._prepare_crops(
pages,
Expand Down
4 changes: 4 additions & 0 deletions doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def __call__(
# Rectify crops if aspect ratio
loc_preds = self._remove_padding(pages, loc_preds)

# Apply hooks to loc_preds if any
for hook in self.hooks:
loc_preds = hook(loc_preds)

# Crop images
crops, loc_preds = self._prepare_crops(
pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
Expand Down
1 change: 0 additions & 1 deletion doctr/models/preprocessor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
batch_size: int,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
fp16: bool = False,
**kwargs: Any,
) -> None:
super().__init__()
Expand Down
1 change: 0 additions & 1 deletion doctr/models/preprocessor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
batch_size: int,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
fp16: bool = False,
**kwargs: Any,
) -> None:
self.batch_size = batch_size
Expand Down
2 changes: 2 additions & 0 deletions scripts/detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def main(args):
detection_model = detection.__dict__[args.detection](
pretrained=True,
bin_thresh=args.bin_thresh,
box_thresh=args.box_thresh,
)
model = ocr_predictor(detection_model, args.recognition, pretrained=True)
path = Path(args.path)
Expand All @@ -86,6 +87,7 @@ def parse_args():
parser.add_argument("path", type=str, help="Path to process: PDF, image, directory")
parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis")
parser.add_argument("--bin-thresh", type=float, default=0.3, help="Binarization threshold for the detection model.")
parser.add_argument("--box-thresh", type=float, default=0.1, help="Threshold for the detection boxes.")
parser.add_argument(
"--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis"
)
Expand Down
10 changes: 10 additions & 0 deletions tests/pytorch/test_models_zoo_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
from doctr.models.recognition.zoo import recognition_predictor


# Create a dummy callback
class _DummyCallback:
def __call__(self, loc_preds):
return loc_preds


@pytest.mark.parametrize(
"assume_straight_pages, straighten_pages",
[
Expand Down Expand Up @@ -121,6 +127,8 @@ def test_trained_ocr_predictor(mock_payslip):
preserve_aspect_ratio=True,
symmetric_pad=True,
)
# test hooks
predictor.add_hook(_DummyCallback())

out = predictor(doc)

Expand Down Expand Up @@ -204,6 +212,8 @@ def test_trained_kie_predictor(mock_payslip):
straighten_pages=True,
preserve_aspect_ratio=False,
)
# test hooks
predictor.add_hook(_DummyCallback())

out = predictor(doc)

Expand Down
10 changes: 10 additions & 0 deletions tests/tensorflow/test_models_zoo_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
from doctr.utils.repr import NestedObject


# Create a dummy callback
class _DummyCallback:
def __call__(self, loc_preds):
return loc_preds


@pytest.mark.parametrize(
"assume_straight_pages, straighten_pages",
[
Expand Down Expand Up @@ -92,6 +98,8 @@ def test_trained_ocr_predictor(mock_payslip):
straighten_pages=True,
preserve_aspect_ratio=False,
)
# test hooks
predictor.add_hook(_DummyCallback())

out = predictor(doc)

Expand Down Expand Up @@ -202,6 +210,8 @@ def test_trained_kie_predictor(mock_payslip):
straighten_pages=True,
preserve_aspect_ratio=False,
)
# test hooks
predictor.add_hook(_DummyCallback())

out = predictor(doc)

Expand Down

0 comments on commit 3811400

Please sign in to comment.