From f807e97f7abf5047c31f2c7f0c3d48c2c66bd6d0 Mon Sep 17 00:00:00 2001
From: Felix Dittrich
Date: Wed, 28 Feb 2024 10:53:34 +0100
Subject: [PATCH] [Docs] add PyTorch / TensorFlow benchmarks (#1321)

---
 docs/source/using_doctr/using_models.rst | 85 ++++++++++--------------
 1 file changed, 35 insertions(+), 50 deletions(-)

diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst
index e906338f5..208e0956b 100644
--- a/docs/source/using_doctr/using_models.rst
+++ b/docs/source/using_doctr/using_models.rst
@@ -28,11 +28,6 @@ The following architectures are currently supported:
 * :py:meth:`db_resnet50 `
 * :py:meth:`db_mobilenet_v3_large `
 
-We also provide 2 models working with any kind of rotated documents:
-
-* :py:meth:`linknet_resnet18_rotation ` (TensorFlow)
-* :py:meth:`db_resnet50_rotation ` (PyTorch)
-
 For a comprehensive comparison, we have compiled a detailed benchmark on publicly available datasets:
 
 
@@ -41,31 +36,27 @@ For a comprehensive comparison, we have compiled a detailed benchmark on publicl
 +================+=================================+=================+==============+============+===============+============+===============+====================+
 | **Backend**    | **Architecture**                | **Input shape** | **# params** | **Recall** | **Precision** | **Recall** | **Precision** | **sec/it (B: 1)**  |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| TensorFlow     | db_resnet50                     | (1024, 1024, 3) | 25.2 M       | 81.22      | 86.66         | 92.46      | 89.62         | 1.2                |
-+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| Tensorflow     | db_mobilenet_v3_large           | (1024, 1024, 3) | 4.2 M        | 78.27      | 82.77         | 80.99      | 66.57         | 0.5                |
-+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| TensorFlow     | linknet_resnet18                | (1024, 1024, 3) | 11.5 M       | 78.23      | 83.77         | 82.88      | 82.42         | 0.7                |
+| TensorFlow     | db_resnet50                     | (1024, 1024, 3) | 25.2 M       | 84.39      | 85.86         | 93.70      | 83.24         | 1.2                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| Tensorflow     | linknet_resnet18_rotation       | (1024, 1024, 3) | 11.5 M       | 81.12      | 82.13         | 83.55      | 80.14         | 0.6                |
+| TensorFlow     | db_mobilenet_v3_large           | (1024, 1024, 3) | 4.2 M        | 80.29      | 70.90         | 84.70      | 67.76         | 0.5                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| TensorFlow     | linknet_resnet34                | (1024, 1024, 3) | 21.6 M       | 82.14      | 87.64         | 85.55      | 86.02         | 0.8                |
+| TensorFlow     | linknet_resnet18                | (1024, 1024, 3) | 11.5 M       | 81.37      | 84.08         | 85.71      | 83.70         | 0.7                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| Tensorflow     | linknet_resnet50                | (1024, 1024, 3) | 28.8 M       | 79.00      | 84.79         | 85.89      | 65.75         | 1.1                |
+| TensorFlow     | linknet_resnet34                | (1024, 1024, 3) | 21.6 M       | 82.20      | 85.49         | 87.63      | 87.17         | 0.8                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | db_resnet34                     | (1024, 1024, 3) | 22.4 M       |            |               |            |               |                    |
+| TensorFlow     | linknet_resnet50                | (1024, 1024, 3) | 28.8 M       | 80.70      | 83.51         | 86.46      | 84.94         | 1.1                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | db_resnet50                     | (1024, 1024, 3) | 25.4 M       | 79.17      | 86.31         | 92.96      | 91.23         | 1.1                |
+| PyTorch        | db_resnet34                     | (1024, 1024, 3) | 22.4 M       | 82.76      | 76.75         | 89.20      | 71.74         | 0.8                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | db_resnet50_rotation            | (1024, 1024, 3) | 25.4 M       | 83.30      | 91.07         | 91.63      | 90.53         | 1.6                |
+| PyTorch        | db_resnet50                     | (1024, 1024, 3) | 25.4 M       | 83.56      | 86.68         | 92.61      | 86.39         | 1.1                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | db_mobilenet_v3_large           | (1024, 1024, 3) | 4.2 M        | 80.06      | 84.12         | 80.51      | 66.51         | 0.5                |
+| PyTorch        | db_mobilenet_v3_large           | (1024, 1024, 3) | 4.2 M        | 83.41      | 84.00         | 86.70      | 79.38         | 0.5                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | linknet_resnet18                | (1024, 1024, 3) | 11.5 M       |            |               |            |               |                    |
+| PyTorch        | linknet_resnet18                | (1024, 1024, 3) | 11.5 M       | 81.64      | 85.52         | 88.92      | 82.74         | 0.6                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | linknet_resnet34                | (1024, 1024, 3) | 21.6 M       |            |               |            |               |                    |
+| PyTorch        | linknet_resnet34                | (1024, 1024, 3) | 21.6 M       | 81.62      | 82.95         | 86.26      | 81.06         | 0.7                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | linknet_resnet50                | (1024, 1024, 3) | 28.8 M       |            |               |            |               |                    |
+| PyTorch        | linknet_resnet50                | (1024, 1024, 3) | 28.8 M       | 81.78      | 82.47         | 87.29      | 85.54         | 1.0                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
 
 
@@ -101,9 +92,7 @@ For instance, this snippet will instantiates a detection predictor able to detec
 .. code:: python3
 
     from doctr.models import detection_predictor
-    predictor = detection_predictor('db_resnet50_rotation', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True)
-
-NB: for the moment, `db_resnet50_rotation` is pretrained in Pytorch only and `linknet_resnet18_rotation` in Tensorflow only.
+    predictor = detection_predictor('db_resnet50', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True)
 
 
 Text Recognition
@@ -137,15 +126,15 @@ For a comprehensive comparison, we have compiled a detailed benchmark on publicl
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
 | TensorFlow     | crnn_vgg16_bn                   | (32, 128, 3)    | 15.8 M       | 88.12      | 88.85         | 94.68      | 95.10         | 0.9                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| Tensorflow     | crnn_mobilenet_v3_small         | (32, 128, 3)    | 2.1 M        | 86.88      | 87.61         | 92.28      | 92.73         | 0.25               |
+| TensorFlow     | crnn_mobilenet_v3_small         | (32, 128, 3)    | 2.1 M        | 86.88      | 87.61         | 92.28      | 92.73         | 0.25               |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
 | TensorFlow     | crnn_mobilenet_v3_large         | (32, 128, 3)    | 4.5 M        | 87.44      | 88.12         | 94.14      | 94.55         | 0.34               |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| Tensorflow     | master                          | (32, 128, 3)    | 58.8 M       | 87.44      | 88.21         | 93.83      | 94.25         | 22.3               |
+| TensorFlow     | master                          | (32, 128, 3)    | 58.8 M       | 87.44      | 88.21         | 93.83      | 94.25         | 22.3               |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
 | TensorFlow     | sar_resnet31                    | (32, 128, 3)    | 57.2 M       | 87.67      | 88.48         | 94.21      | 94.66         | 7.1                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| Tensorflow     | vitstr_small                    | (32, 128, 3)    | 21.4 M       | 83.01      | 83.84         | 86.57      | 87.00         | 2.0                |
+| TensorFlow     | vitstr_small                    | (32, 128, 3)    | 21.4 M       | 83.01      | 83.84         | 86.57      | 87.00         | 2.0                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
 | TensorFlow     | vitstr_base                     | (32, 128, 3)    | 85.2 M       | 85.98      | 86.70         | 90.47      | 90.95         | 5.8                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
@@ -157,15 +146,15 @@ For a comprehensive comparison, we have compiled a detailed benchmark on publicl
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
 | PyTorch        | crnn_mobilenet_v3_large         | (32, 128, 3)    | 4.5 M        | 87.38      | 88.09         | 94.46      | 94.92         | 0.08               |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | master                          | (32, 128, 3)    | 58.7 M       |            |               |            |               | 17.6               |
+| PyTorch        | master                          | (32, 128, 3)    | 58.7 M       | 88.57      | 89.39         | 95.73      | 96.21         | 17.6               |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | sar_resnet31                    | (32, 128, 3)    | 55.4 M       |            |               |            |               | 4.9                |
+| PyTorch        | sar_resnet31                    | (32, 128, 3)    | 55.4 M       | 88.10      | 88.88         | 94.83      | 95.29         | 4.9                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | vitstr_small                    | (32, 128, 3)    | 21.4 M       |            |               |            |               | 1.5                |
+| PyTorch        | vitstr_small                    | (32, 128, 3)    | 21.4 M       | 88.00      | 88.82         | 95.40      | 95.78         | 1.5                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | vitstr_base                     | (32, 128, 3)    | 85.2 M       |            |               |            |               | 4.1                |
+| PyTorch        | vitstr_base                     | (32, 128, 3)    | 85.2 M       | 88.33      | 89.09         | 95.32      | 95.71         | 4.1                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
-| PyTorch        | parseq                          | (32, 128, 3)    | 23.8 M       |            |               |            |               | 2.2                |
+| PyTorch        | parseq                          | (32, 128, 3)    | 23.8 M       | 88.53      | 89.24         | 95.56      | 95.91         | 2.2                |
 +----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
 
 
@@ -216,37 +205,33 @@ For a comprehensive comparison, we have compiled a detailed benchmark on publicl
 +================+==========================================================+============================+============+===============+
 | **Backend**    | **Architecture**                                         | **Recall** | **Precision** | **Recall** | **Precision** |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| TensorFlow     | db_resnet50 + crnn_vgg16_bn                              | 70.82      | 75.56         | 83.97      | 81.40         |
-+----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| TensorFlow     | db_resnet50 + crnn_mobilenet_v3_small                    | 69.63      | 74.29         | 81.08      | 78.59         |
-+----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| TensorFlow     | db_resnet50 + crnn_mobilenet_v3_large                    | 70.01      | 74.70         | 83.28      | 80.73         |
+| TensorFlow     | db_resnet50 + crnn_vgg16_bn                              | 73.45      | 74.73         | 85.79      | 76.21         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| TensorFlow     | db_resnet50 + sar_resnet31                               | 68.75      | 73.76         | 78.56      | 76.24         |
+| TensorFlow     | db_resnet50 + crnn_mobilenet_v3_small                    | 72.66      | 73.93         | 83.43      | 74.11         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| TensorFlow     | db_resnet50 + master                                     | 68.75      | 73.76         | 78.56      | 76.24         |
+| TensorFlow     | db_resnet50 + crnn_mobilenet_v3_large                    | 72.86      | 74.13         | 85.16      | 75.65         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| TensorFlow     | db_resnet50 + vitstr_small                               | 64.58      | 68.91         | 74.66      | 72.37         |
+| TensorFlow     | db_resnet50 + master                                     | 72.73      | 74.00         | 84.13      | 75.05         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| TensorFlow     | db_resnet50 + vitstr_base                                | 66.89      | 71.37         | 79.11      | 76.68         |
+| TensorFlow     | db_resnet50 + vitstr_small                               | 68.57      | 69.77         | 78.24      | 69.51         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| TensorFlow     | db_resnet50 + parseq                                     | 65.77      | 70.18         | 71.57      | 69.37         |
+| TensorFlow     | db_resnet50 + vitstr_base                                | 70.96      | 72.20         | 82.10      | 72.94         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| PyTorch        | db_resnet50 + crnn_vgg16_bn                              | 67.82      | 73.35         | 84.84      | 83.27         |
+| TensorFlow     | db_resnet50 + parseq                                     | 68.85      | 70.05         | 72.38      | 64.30         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| PyTorch        | db_resnet50 + crnn_mobilenet_v3_small                    | 67.89      | 74.01         | 84.43      | 82.85         |
+| PyTorch        | db_resnet50 + crnn_vgg16_bn                              | 72.43      | 75.13         | 85.05      | 79.33         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| PyTorch        | db_resnet50 + crnn_mobilenet_v3_large                    | 68.45      | 74.63         | 84.86      | 83.27         |
+| PyTorch        | db_resnet50 + crnn_mobilenet_v3_small                    | 73.06      | 75.79         | 84.64      | 78.94         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| PyTorch        | db_resnet50 + sar_resnet31                               |            |               |            |               |
+| PyTorch        | db_resnet50 + crnn_mobilenet_v3_large                    | 73.17      | 75.90         | 84.96      | 79.25         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| PyTorch        | db_resnet50 + master                                     |            |               |            |               |
+| PyTorch        | db_resnet50 + master                                     | 73.90      | 76.66         | 85.84      | 80.07         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| PyTorch        | db_resnet50 + vitstr_small                               |            |               |            |               |
+| PyTorch        | db_resnet50 + vitstr_small                               | 73.06      | 75.79         | 85.95      | 80.17         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| PyTorch        | db_resnet50 + vitstr_base                                |            |               |            |               |
+| PyTorch        | db_resnet50 + vitstr_base                                | 73.70      | 76.46         | 85.76      | 79.99         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
-| PyTorch        | db_resnet50 + parseq                                     |            |               |            |               |
+| PyTorch        | db_resnet50 + parseq                                     | 73.52      | 76.27         | 85.91      | 80.13         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
 | None           | Gvision text detection                                   | 59.50      | 62.50         | 75.30      | 59.03         |
 +----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
@@ -292,7 +277,7 @@ For instance, this snippet instantiates an end-to-end ocr_predictor working with
 .. code:: python3
 
     from doctr.model import ocr_predictor
-    model = ocr_predictor('linknet_resnet18_rotation', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True)
+    model = ocr_predictor('linknet_resnet18', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True)
 
 
 What should I do with the output?