From 495151ce51524051c9b38e105a907ce95ed43bf3 Mon Sep 17 00:00:00 2001
From: Felix Dittrich
Date: Wed, 12 Jun 2024 16:30:07 +0200
Subject: [PATCH] [Sync] page orientation integration (#16)

---
 README.md                            |  2 +-
 onnxtr/models/_utils.py              | 73 +++++++++++++++++++++-------
 onnxtr/models/predictor/base.py      | 62 +++++++++++++++++++++--
 onnxtr/models/predictor/predictor.py | 26 +++++-----
 pyproject.toml                       |  5 +-
 tests/common/test_models.py          | 21 +++++---
 tests/common/test_models_zoo.py      |  5 ++
 7 files changed, 153 insertions(+), 41 deletions(-)

diff --git a/README.md b/README.md
index ed1eac4..458a136 100644
--- a/README.md
+++ b/README.md
@@ -68,7 +68,7 @@ multi_img_doc = DocumentFile.from_images(["path/to/page1.jpg", "path/to/page2.jp
 
 ### Putting it together
 
-Let's use the default pretrained model for an example:
+Let's use the default `ocr_predictor` model for an example:
 
 ```python
 from onnxtr.io import DocumentFile
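For context, the hunk above opens the README's "Putting it together" snippet; a minimal sketch of how that example continues with the default predictor (the document path is a placeholder):

```python
from onnxtr.io import DocumentFile
from onnxtr.models import ocr_predictor

model = ocr_predictor()  # default detection + recognition models

# Read a PDF (placeholder path) and run the full OCR pipeline on it
doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
result = model(doc)
```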
diff --git a/onnxtr/models/_utils.py b/onnxtr/models/_utils.py
index 89c03d7..2710141 100644
--- a/onnxtr/models/_utils.py
+++ b/onnxtr/models/_utils.py
@@ -11,6 +11,8 @@
 import numpy as np
 from langdetect import LangDetectException, detect_langs
 
+from onnxtr.utils.geometry import rotate_image
+
 __all__ = ["estimate_orientation", "get_language"]
 
 
@@ -29,42 +31,63 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
     return max(w / h, h / w)
 
 
-def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> int:
+def estimate_orientation(
+    img: np.ndarray,
+    general_page_orientation: Optional[Tuple[int, float]] = None,
+    n_ct: int = 70,
+    ratio_threshold_for_lines: float = 3,
+    min_confidence: float = 0.2,
+    lower_area: int = 100,
+) -> int:
     """Estimate the angle of the general document orientation based on the
     lines of the document and the assumption that they should be horizontal.
 
     Args:
     ----
         img: the img or bitmap to analyze (H, W, C)
+        general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence)
+            estimated by a model
        n_ct: the number of contours used for the orientation estimation
        ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines
+        min_confidence: the minimum confidence to consider the general_page_orientation
+        lower_area: the minimum area of a contour to be considered
 
     Returns:
     -------
-        the angle of the general document orientation
+        the estimated angle of the page (clockwise; negative for left-side rotation, positive for right-side rotation)
     """
     assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
-    max_value = np.max(img)
-    min_value = np.min(img)
-    if max_value <= 1 and min_value >= 0 or (max_value <= 255 and min_value >= 0 and img.shape[-1] == 1):
-        thresh = img.astype(np.uint8)
-    if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
+    thresh = None
+    # Convert image to grayscale if necessary
+    if img.shape[-1] == 3:
         gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
         gray_img = cv2.medianBlur(gray_img, 5)
         thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
-
-    # try to merge words in lines
-    (h, w) = img.shape[:2]
-    k_x = max(1, (floor(w / 100)))
-    k_y = max(1, (floor(h / 100)))
-    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
-    thresh = cv2.dilate(thresh, kernel, iterations=1)
+    else:
+        thresh = img.astype(np.uint8)  # type: ignore[assignment]
+
+    page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
+    if page_orientation and orientation_confidence >= min_confidence:
+        # We rotate the image to the general orientation, which improves the detection
+        # No expand needed, the bitmap is already padded
+        thresh = rotate_image(thresh, -page_orientation)  # type: ignore
+    else:  # That's only required if we do not work on the detection model's bin map
+        # try to merge words in lines
+        (h, w) = img.shape[:2]
+        k_x = max(1, (floor(w / 100)))
+        k_y = max(1, (floor(h / 100)))
+        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
+        thresh = cv2.dilate(thresh, kernel, iterations=1)
 
     # extract contours
     contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
 
-    # Sort contours
-    contours = sorted(contours, key=get_max_width_length_ratio, reverse=True)
+    # Filter & Sort contours
+    contours = sorted(
+        [contour for contour in contours if cv2.contourArea(contour) > lower_area],
+        key=get_max_width_length_ratio,
+        reverse=True,
+    )
 
     angles = []
     for contour in contours[:n_ct]:
@@ -75,10 +98,24 @@ def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_li
         angles.append(angle - 90)
 
     if len(angles) == 0:
-        return 0  # in case no angles is found
+        estimated_angle = 0  # in case no angle is found
     else:
         median = -median_low(angles)
-        return round(median) if abs(median) != 0 else 0
+        estimated_angle = -round(median) if abs(median) != 0 else 0
+
+    # combine the general page orientation with the estimated angle
+    if page_orientation and orientation_confidence >= min_confidence:
+        # special case where the estimated angle is mostly wrong:
+        # case 1: - and + swapped
+        # case 2: estimated angle is completely wrong
+        # so in this case we prefer the general page orientation
+        if abs(estimated_angle) == abs(page_orientation):
+            return page_orientation
+        estimated_angle = estimated_angle if page_orientation == 0 else page_orientation + estimated_angle
+        if estimated_angle > 180:
+            estimated_angle -= 360
+
+    return estimated_angle  # return the clockwise angle (negative - left side rotation, positive - right side rotation)
 
 
 def rectify_crops(
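Since the new keyword changes how the angle is derived, here is a minimal sketch of both call styles, assuming a synthetic placeholder page and a hypothetical (angle, confidence) hint coming from a page orientation model:

```python
import numpy as np

from onnxtr.models._utils import estimate_orientation

page = np.random.randint(0, 255, (1024, 768, 3), dtype=np.uint8)  # placeholder page (H, W, C)

# Contour-based estimate only (previous behaviour)
angle = estimate_orientation(page)

# With a coarse page orientation predicted by a classification model:
# the hint is only applied when its confidence reaches `min_confidence`,
# otherwise the function falls back to the pure contour-based estimate
angle = estimate_orientation(page, general_page_orientation=(90, 0.95), min_confidence=0.2)
```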
diff --git a/onnxtr/models/predictor/base.py b/onnxtr/models/predictor/base.py
index d6d2aa9..aa0a690 100644
--- a/onnxtr/models/predictor/base.py
+++ b/onnxtr/models/predictor/base.py
@@ -8,10 +8,10 @@
 import numpy as np
 
 from onnxtr.models.builder import DocumentBuilder
-from onnxtr.utils.geometry import extract_crops, extract_rcrops
+from onnxtr.utils.geometry import extract_crops, extract_rcrops, rotate_image
 
-from .._utils import rectify_crops, rectify_loc_preds
-from ..classification import crop_orientation_predictor
+from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
+from ..classification import crop_orientation_predictor, page_orientation_predictor
 from ..classification.predictor import OrientationPredictor
 from ..detection.zoo import ARCHS as DETECTION_ARCHS
 from ..recognition.zoo import ARCHS as RECOGNITION_ARCHS
@@ -31,11 +31,14 @@ class _OCRPredictor:
             accordingly. Doing so will improve performances for documents with page-uniform rotations.
         preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
         symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically.
+        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
+            page. Doing so will slightly deteriorate the overall latency.
         load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
         **kwargs: keyword args of `DocumentBuilder`
     """
 
     crop_orientation_predictor: Optional[OrientationPredictor]
+    page_orientation_predictor: Optional[OrientationPredictor]
 
     def __init__(
         self,
@@ -43,6 +46,7 @@ def __init__(
         straighten_pages: bool = False,
         preserve_aspect_ratio: bool = True,
         symmetric_pad: bool = True,
+        detect_orientation: bool = False,
         load_in_8_bit: bool = False,
         **kwargs: Any,
     ) -> None:
@@ -51,11 +55,63 @@ def __init__(
         self.crop_orientation_predictor = (
             None if assume_straight_pages else crop_orientation_predictor(load_in_8_bit=load_in_8_bit)
         )
+        self.page_orientation_predictor = (
+            page_orientation_predictor(load_in_8_bit=load_in_8_bit)
+            if detect_orientation or straighten_pages or not assume_straight_pages
+            else None
+        )
         self.doc_builder = DocumentBuilder(**kwargs)
         self.preserve_aspect_ratio = preserve_aspect_ratio
         self.symmetric_pad = symmetric_pad
         self.hooks: List[Callable] = []
 
+    def _general_page_orientations(
+        self,
+        pages: List[np.ndarray],
+    ) -> List[Tuple[int, float]]:
+        _, classes, probs = zip(self.page_orientation_predictor(pages))  # type: ignore[misc]
+        # Flatten to list of tuples with (value, confidence)
+        page_orientations = [
+            (orientation, prob)
+            for page_classes, page_probs in zip(classes, probs)
+            for orientation, prob in zip(page_classes, page_probs)
+        ]
+        return page_orientations
+
+    def _get_orientations(
+        self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
+    ) -> Tuple[List[Tuple[int, float]], List[int]]:
+        general_pages_orientations = self._general_page_orientations(pages)
+        origin_page_orientations = [
+            estimate_orientation(seq_map, general_orientation)
+            for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
+        ]
+        return general_pages_orientations, origin_page_orientations
+
+    def _straighten_pages(
+        self,
+        pages: List[np.ndarray],
+        seg_maps: List[np.ndarray],
+        general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
+        origin_pages_orientations: Optional[List[int]] = None,
+    ) -> List[np.ndarray]:
+        general_pages_orientations = (
+            general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
+        )
+        origin_pages_orientations = (
+            origin_pages_orientations
+            if origin_pages_orientations
+            else [
+                estimate_orientation(seq_map, general_orientation)
+                for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
+            ]
+        )
+        return [
+            # We expand if the page is wider than tall and the angle is 90 or -90
+            rotate_image(page, angle, expand=page.shape[1] > page.shape[0] and abs(angle) == 90)
+            for page, angle in zip(pages, origin_pages_orientations)
+        ]
+
     @staticmethod
     def _generate_crops(
         pages: List[np.ndarray],
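To make the flattening in `_general_page_orientations` easier to follow, a rough sketch of the underlying predictor call; it assumes the factory's defaults and that the orientation predictor returns `[class_indices, class_values, probabilities]`, which is what the comprehension above implies:

```python
import numpy as np

from onnxtr.models.classification import page_orientation_predictor

predictor = page_orientation_predictor()  # same factory used in _OCRPredictor.__init__
pages = [np.random.randint(0, 255, (1024, 768, 3), dtype=np.uint8)]  # placeholder page

# Assumed output layout: [class_indices, class_values, probabilities]
_, angles, probs = predictor(pages)
general_page_orientations = list(zip(angles, probs))  # e.g. [(90, 0.97)] -> one (angle, confidence) per page
```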
diff --git a/onnxtr/models/predictor/predictor.py b/onnxtr/models/predictor/predictor.py
index c400b91..ab54311 100644
--- a/onnxtr/models/predictor/predictor.py
+++ b/onnxtr/models/predictor/predictor.py
@@ -8,10 +8,10 @@
 import numpy as np
 
 from onnxtr.io.elements import Document
-from onnxtr.models._utils import estimate_orientation, get_language
+from onnxtr.models._utils import get_language
 from onnxtr.models.detection.predictor import DetectionPredictor
 from onnxtr.models.recognition.predictor import RecognitionPredictor
-from onnxtr.utils.geometry import detach_scores, rotate_image
+from onnxtr.utils.geometry import detach_scores
 from onnxtr.utils.repr import NestedObject
 
 from .base import _OCRPredictor
@@ -55,7 +55,13 @@ def __init__(
         self.det_predictor = det_predictor
         self.reco_predictor = reco_predictor
         _OCRPredictor.__init__(
-            self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
+            self,
+            assume_straight_pages,
+            straighten_pages,
+            preserve_aspect_ratio,
+            symmetric_pad,
+            detect_orientation,
+            **kwargs,
         )
         self.detect_orientation = detect_orientation
         self.detect_language = detect_language
@@ -80,19 +86,17 @@ def __call__(
             for out_map in out_maps
         ]
         if self.detect_orientation:
-            origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
+            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
             orientations = [
-                {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
             ]
         else:
             orientations = None
+            general_pages_orientations = None
+            origin_pages_orientations = None
 
         if self.straighten_pages:
-            origin_page_orientations = (
-                origin_page_orientations
-                if self.detect_orientation
-                else [estimate_orientation(seq_map) for seq_map in seg_maps]
-            )
-            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+            # forward again to get predictions on straight pages
             loc_preds = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
 
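At the user level, the new flags combine roughly as follows; a sketch assuming the top-level `ocr_predictor` factory forwards them as the `__init__` above suggests (the document path is a placeholder):

```python
from onnxtr.io import DocumentFile
from onnxtr.models import ocr_predictor

predictor = ocr_predictor(
    assume_straight_pages=False,  # enables the crop & page orientation handling
    straighten_pages=True,        # pages are rotated back before the second detection pass
    detect_orientation=True,      # attaches the per-page orientation to the output
)
doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
result = predictor(doc)
print(result.pages[0].orientation)  # e.g. {"value": -90, "confidence": None}
```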
diff --git a/pyproject.toml b/pyproject.toml
index c320c7d..d213f20 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,7 +96,10 @@ changelog = "https://github.com/felixdittrich92/OnnxTR/releases"
 zip-safe = true
 
 [tool.setuptools.packages.find]
-exclude = ["tests*", "scripts*"]
+exclude = ["docs*", "tests*", "scripts*"]
+
+[tool.setuptools.package-data]
+doctr = ["py.typed"]
 
 [tool.mypy]
 files = "onnxtr/"
diff --git a/tests/common/test_models.py b/tests/common/test_models.py
index be7bfab..cd93e34 100644
--- a/tests/common/test_models.py
+++ b/tests/common/test_models.py
@@ -33,28 +33,35 @@ def test_estimate_orientation(mock_image, mock_bitmap, mock_tilted_payslip):
 
     # test binarized image
     angle = estimate_orientation(mock_bitmap)
-    assert abs(angle - 30.0) < 1.0
+    assert abs(angle) - 30 < 1.0
 
     angle = estimate_orientation(mock_bitmap * 255)
-    assert abs(angle - 30.0) < 1.0
+    assert abs(angle) - 30.0 < 1.0
 
     angle = estimate_orientation(mock_image)
-    assert abs(angle - 30.0) < 1.0
+    assert abs(angle) - 30.0 < 1.0
 
-    rotated = geometry.rotate_image(mock_image, -angle)
+    rotated = geometry.rotate_image(mock_image, angle)
     angle_rotated = estimate_orientation(rotated)
-    assert abs(angle_rotated) < 1.0
+    assert abs(angle_rotated) == 0
 
     mock_tilted_payslip = reader.read_img_as_numpy(mock_tilted_payslip)
-    assert (estimate_orientation(mock_tilted_payslip) - 30.0) < 1.0
+    assert estimate_orientation(mock_tilted_payslip) == -30
 
     rotated = geometry.rotate_image(mock_tilted_payslip, -30, expand=True)
     angle_rotated = estimate_orientation(rotated)
     assert abs(angle_rotated) < 1.0
-
     with pytest.raises(AssertionError):
         estimate_orientation(np.ones((10, 10, 10)))
 
+    # test with general_page_orientation
+    assert estimate_orientation(mock_bitmap, (90, 0.9)) in range(140, 160)
+
+    rotated = geometry.rotate_image(mock_tilted_payslip, -30)
+    assert estimate_orientation(rotated, (0, 0.9)) in range(-10, 10)
+
+    assert estimate_orientation(mock_image, (0, 0.9)) - 30 < 1.0
+
 
 def test_get_lang():
     sentence = "This is a test sentence."
diff --git a/tests/common/test_models_zoo.py b/tests/common/test_models_zoo.py
index 40f8e84..8937242 100644
--- a/tests/common/test_models_zoo.py
+++ b/tests/common/test_models_zoo.py
@@ -56,8 +56,13 @@ def test_ocrpredictor(mock_pdf, assume_straight_pages, straighten_pages):
 
     if assume_straight_pages:
         assert predictor.crop_orientation_predictor is None
+        if predictor.detect_orientation or predictor.straighten_pages:
+            assert isinstance(predictor.page_orientation_predictor, NestedObject)
+        else:
+            assert predictor.page_orientation_predictor is None
     else:
         assert isinstance(predictor.crop_orientation_predictor, NestedObject)
+        assert isinstance(predictor.page_orientation_predictor, NestedObject)
 
     out = predictor(doc)
     assert isinstance(out, Document)
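The gating exercised by the zoo test above can be summarised in a short sketch, assuming the top-level `ocr_predictor` factory forwards the flags to `_OCRPredictor` as shown earlier:

```python
from onnxtr.models import ocr_predictor

# Default configuration: straight pages assumed, no page orientation model needed
predictor = ocr_predictor(assume_straight_pages=True, straighten_pages=False, detect_orientation=False)
assert predictor.page_orientation_predictor is None

# Any of the three conditions flips the gate and instantiates the page orientation model
predictor = ocr_predictor(assume_straight_pages=False)
assert predictor.page_orientation_predictor is not None
```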