Skip to content

Commit

Permalink
bench + reparam (mindee#1519)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Mar 22, 2024
1 parent 6bcc0c6 commit afb9358
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source/using_doctr/using_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ For a comprehensive comparison, we have compiled a detailed benchmark on publicl
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | linknet_resnet50 | (1024, 1024, 3) | 28.8 M | 81.78 | 82.47 | 87.29 | 85.54 | 1.0 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_tiny | (1024, 1024, 3) | 13.5 M (8.5M) | | | | | 0.7 (0.4) |
| PyTorch | fast_tiny | (1024, 1024, 3) | 13.5 M (8.5M) | 84.90 | 85.04 | 93.73 | 76.26 | 0.7 (0.4) |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_small | (1024, 1024, 3) | 14.7 M (9.7M) | | | | | 0.7 (0.5) |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class FASTPostProcessor(DetectionPostProcessor):

def __init__(
self,
bin_thresh: float = 0.3,
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class FAST(_FAST, nn.Module):
def __init__(
self,
feat_extractor: IntermediateLayerGetter,
bin_thresh: float = 0.3,
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
dropout_prob: float = 0.1,
pooling_size: int = 4, # different from paper performs better on close text-rich images
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class FAST(_FAST, keras.Model, NestedObject):
def __init__(
self,
feature_extractor: IntermediateLayerGetter,
bin_thresh: float = 0.3,
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
dropout_prob: float = 0.1,
pooling_size: int = 4, # different from paper performs better on close text-rich images
Expand Down
4 changes: 4 additions & 0 deletions doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from doctr.file_utils import is_tf_available, is_torch_available

from .. import detection
from ..detection.fast import reparameterize
from ..preprocessor import PreProcessor
from .predictor import DetectionPredictor

Expand Down Expand Up @@ -51,6 +52,9 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
pretrained_backbone=kwargs.get("pretrained_backbone", True),
assume_straight_pages=assume_straight_pages,
)
# Reparameterize FAST models by default to lower inference latency and memory usage
if isinstance(_model, detection.FAST):
_model = reparameterize(_model)
else:
if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
raise ValueError(f"unknown architecture: {type(arch)}")
Expand Down

0 comments on commit afb9358

Please sign in to comment.