diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..24027a2 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4cc9a01 --- /dev/null +++ b/.gitignore @@ -0,0 +1,165 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
+# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +.vscode/ +images/ +models/ +logs/ \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1c50bb3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +torch +torchvision +torchaudio +Pillow +opencv-python==4.5.1.48 +tqdm diff --git a/result/.DS_Store b/result/.DS_Store new file mode 100644 index 0000000..02455de Binary files /dev/null and b/result/.DS_Store differ diff --git a/result/1AD7899.jpg b/result/1AD7899.jpg new file mode 100644 index 0000000..2e17755 Binary files /dev/null and b/result/1AD7899.jpg differ diff --git a/result/8273BXG7D.jpg b/result/8273BXG7D.jpg new file mode 100644 index 0000000..1d1e203 Binary files /dev/null and b/result/8273BXG7D.jpg differ diff --git a/result/A1214YZ.jpg b/result/A1214YZ.jpg new file mode 100644 index 0000000..07520d2 Binary files /dev/null and b/result/A1214YZ.jpg differ diff --git a/result/A1227BJ.jpg b/result/A1227BJ.jpg new file mode 100644 index 0000000..9503ecd Binary files /dev/null and b/result/A1227BJ.jpg differ diff --git a/result/A1251RP.jpg b/result/A1251RP.jpg new file mode 100644 index 0000000..1ce832f Binary files /dev/null and b/result/A1251RP.jpg differ diff --git a/result/A1260YD.jpg b/result/A1260YD.jpg new file mode 100644 index 0000000..137c90f Binary files /dev/null and b/result/A1260YD.jpg differ diff --git a/result/A1758BO.jpg b/result/A1758BO.jpg new file mode 100644 index 0000000..d5aa2e6 Binary files /dev/null and b/result/A1758BO.jpg differ diff --git a/result/A1788WC.jpg b/result/A1788WC.jpg new file mode 100644 index 0000000..5707b8d Binary files /dev/null and b/result/A1788WC.jpg differ diff --git a/result/A685PJ.jpg b/result/A685PJ.jpg new file mode 100644 index 0000000..1769f97 Binary files /dev/null and b/result/A685PJ.jpg differ diff --git a/result/A907R.jpg b/result/A907R.jpg new file mode 100644 index 0000000..2a20d25 Binary files /dev/null and b/result/A907R.jpg differ diff --git a/result/AB1092TY.jpg 
b/result/AB1092TY.jpg new file mode 100644 index 0000000..8dce388 Binary files /dev/null and b/result/AB1092TY.jpg differ diff --git a/result/AB1095ZU.jpg b/result/AB1095ZU.jpg new file mode 100644 index 0000000..00f1396 Binary files /dev/null and b/result/AB1095ZU.jpg differ diff --git a/result/AB1398GN.jpg b/result/AB1398GN.jpg new file mode 100644 index 0000000..131b2e2 Binary files /dev/null and b/result/AB1398GN.jpg differ diff --git a/result/AB1514HA.jpg b/result/AB1514HA.jpg new file mode 100644 index 0000000..da36ad2 Binary files /dev/null and b/result/AB1514HA.jpg differ diff --git a/result/AB1653DA.jpg b/result/AB1653DA.jpg new file mode 100644 index 0000000..659fce7 Binary files /dev/null and b/result/AB1653DA.jpg differ diff --git a/result/AB167ZIZ.jpg b/result/AB167ZIZ.jpg new file mode 100644 index 0000000..c874896 Binary files /dev/null and b/result/AB167ZIZ.jpg differ diff --git a/result/AB19661N.jpg b/result/AB19661N.jpg new file mode 100644 index 0000000..def5d9b Binary files /dev/null and b/result/AB19661N.jpg differ diff --git a/result/AB8027K.jpg b/result/AB8027K.jpg new file mode 100644 index 0000000..96ee6a2 Binary files /dev/null and b/result/AB8027K.jpg differ diff --git a/result/AB8966U.jpg b/result/AB8966U.jpg new file mode 100644 index 0000000..901b877 Binary files /dev/null and b/result/AB8966U.jpg differ diff --git a/result/AD1345UB.jpg b/result/AD1345UB.jpg new file mode 100644 index 0000000..09f7d3a Binary files /dev/null and b/result/AD1345UB.jpg differ diff --git a/result/AD8819RN.jpg b/result/AD8819RN.jpg new file mode 100644 index 0000000..244b755 Binary files /dev/null and b/result/AD8819RN.jpg differ diff --git a/result/AD8907BE.jpg b/result/AD8907BE.jpg new file mode 100644 index 0000000..7c7fa39 Binary files /dev/null and b/result/AD8907BE.jpg differ diff --git a/result/AD8946BE.jpg b/result/AD8946BE.jpg new file mode 100644 index 0000000..e101b7c Binary files /dev/null and b/result/AD8946BE.jpg differ diff --git 
a/result/AD9238LU.jpg b/result/AD9238LU.jpg new file mode 100644 index 0000000..a10fe7a Binary files /dev/null and b/result/AD9238LU.jpg differ diff --git a/result/AD9312DS.jpg b/result/AD9312DS.jpg new file mode 100644 index 0000000..0993ebc Binary files /dev/null and b/result/AD9312DS.jpg differ diff --git a/result/AD9679UB.jpg b/result/AD9679UB.jpg new file mode 100644 index 0000000..2b8b962 Binary files /dev/null and b/result/AD9679UB.jpg differ diff --git a/result/AG8646UF.jpg b/result/AG8646UF.jpg new file mode 100644 index 0000000..c12b56a Binary files /dev/null and b/result/AG8646UF.jpg differ diff --git a/result/AG8T35RN.jpg b/result/AG8T35RN.jpg new file mode 100644 index 0000000..89e28cb Binary files /dev/null and b/result/AG8T35RN.jpg differ diff --git a/result/AG9402JUK.jpg b/result/AG9402JUK.jpg new file mode 100644 index 0000000..60e056f Binary files /dev/null and b/result/AG9402JUK.jpg differ diff --git a/result/AG9575PG.jpg b/result/AG9575PG.jpg new file mode 100644 index 0000000..8569c3d Binary files /dev/null and b/result/AG9575PG.jpg differ diff --git a/result/B100VV.jpg b/result/B100VV.jpg new file mode 100644 index 0000000..b0135e1 Binary files /dev/null and b/result/B100VV.jpg differ diff --git a/result/B212LOH.jpg b/result/B212LOH.jpg new file mode 100644 index 0000000..bb0a0a0 Binary files /dev/null and b/result/B212LOH.jpg differ diff --git a/result/B2412PBA.jpg b/result/B2412PBA.jpg new file mode 100644 index 0000000..dc4fbd3 Binary files /dev/null and b/result/B2412PBA.jpg differ diff --git a/result/B2417BRT.jpg b/result/B2417BRT.jpg new file mode 100644 index 0000000..90dae60 Binary files /dev/null and b/result/B2417BRT.jpg differ diff --git a/result/B242LRH.jpg b/result/B242LRH.jpg new file mode 100644 index 0000000..7edf746 Binary files /dev/null and b/result/B242LRH.jpg differ diff --git a/result/B2438SIH.jpg b/result/B2438SIH.jpg new file mode 100644 index 0000000..640aeba Binary files /dev/null and b/result/B2438SIH.jpg differ diff 
--git a/result/B2461SYN.jpg b/result/B2461SYN.jpg new file mode 100644 index 0000000..7b7d282 Binary files /dev/null and b/result/B2461SYN.jpg differ diff --git a/result/B2469SOD.jpg b/result/B2469SOD.jpg new file mode 100644 index 0000000..7dcb8b0 Binary files /dev/null and b/result/B2469SOD.jpg differ diff --git a/result/B2477TBJ.jpg b/result/B2477TBJ.jpg new file mode 100644 index 0000000..d9ca0fe Binary files /dev/null and b/result/B2477TBJ.jpg differ diff --git a/result/B2487SLY.jpg b/result/B2487SLY.jpg new file mode 100644 index 0000000..146fdaa Binary files /dev/null and b/result/B2487SLY.jpg differ diff --git a/result/B2501KOB.jpg b/result/B2501KOB.jpg new file mode 100644 index 0000000..51e48f5 Binary files /dev/null and b/result/B2501KOB.jpg differ diff --git a/result/B2514SON.jpg b/result/B2514SON.jpg new file mode 100644 index 0000000..1087585 Binary files /dev/null and b/result/B2514SON.jpg differ diff --git a/result/B2516FFX.jpg b/result/B2516FFX.jpg new file mode 100644 index 0000000..9a1ed73 Binary files /dev/null and b/result/B2516FFX.jpg differ diff --git a/result/B2519TIU.jpg b/result/B2519TIU.jpg new file mode 100644 index 0000000..3aeff2d Binary files /dev/null and b/result/B2519TIU.jpg differ diff --git a/result/B2531SKS.jpg b/result/B2531SKS.jpg new file mode 100644 index 0000000..892478b Binary files /dev/null and b/result/B2531SKS.jpg differ diff --git a/result/B2590SBA.jpg b/result/B2590SBA.jpg new file mode 100644 index 0000000..c5475bc Binary files /dev/null and b/result/B2590SBA.jpg differ diff --git a/result/B25ZL.jpg b/result/B25ZL.jpg new file mode 100644 index 0000000..73afa93 Binary files /dev/null and b/result/B25ZL.jpg differ diff --git a/result/B2602BFE.jpg b/result/B2602BFE.jpg new file mode 100644 index 0000000..b584036 Binary files /dev/null and b/result/B2602BFE.jpg differ diff --git a/result/B2603TTE.jpg b/result/B2603TTE.jpg new file mode 100644 index 0000000..5528503 Binary files /dev/null and b/result/B2603TTE.jpg 
differ diff --git a/result/B2616TOE.jpg b/result/B2616TOE.jpg new file mode 100644 index 0000000..a9e9837 Binary files /dev/null and b/result/B2616TOE.jpg differ diff --git a/result/B2635TYM.jpg b/result/B2635TYM.jpg new file mode 100644 index 0000000..5027617 Binary files /dev/null and b/result/B2635TYM.jpg differ diff --git a/result/B2641UZD.jpg b/result/B2641UZD.jpg new file mode 100644 index 0000000..d020d4a Binary files /dev/null and b/result/B2641UZD.jpg differ diff --git a/result/B2649TGZ.jpg b/result/B2649TGZ.jpg new file mode 100644 index 0000000..510b473 Binary files /dev/null and b/result/B2649TGZ.jpg differ diff --git a/result/B2658KFX.jpg b/result/B2658KFX.jpg new file mode 100644 index 0000000..70bf394 Binary files /dev/null and b/result/B2658KFX.jpg differ diff --git a/result/B2695SZL.jpg b/result/B2695SZL.jpg new file mode 100644 index 0000000..2c111c4 Binary files /dev/null and b/result/B2695SZL.jpg differ diff --git a/result/B2707UKT.jpg b/result/B2707UKT.jpg new file mode 100644 index 0000000..ace929b Binary files /dev/null and b/result/B2707UKT.jpg differ diff --git a/result/B2708SZT.jpg b/result/B2708SZT.jpg new file mode 100644 index 0000000..52a7fc5 Binary files /dev/null and b/result/B2708SZT.jpg differ diff --git a/result/B2709RR.jpg b/result/B2709RR.jpg new file mode 100644 index 0000000..73fb57a Binary files /dev/null and b/result/B2709RR.jpg differ diff --git a/result/B2741POD.jpg b/result/B2741POD.jpg new file mode 100644 index 0000000..8fc814d Binary files /dev/null and b/result/B2741POD.jpg differ diff --git a/result/B2741TZA.jpg b/result/B2741TZA.jpg new file mode 100644 index 0000000..15d211f Binary files /dev/null and b/result/B2741TZA.jpg differ diff --git a/result/B2744PKX.jpg b/result/B2744PKX.jpg new file mode 100644 index 0000000..c0c819b Binary files /dev/null and b/result/B2744PKX.jpg differ diff --git a/result/B2745SZV.jpg b/result/B2745SZV.jpg new file mode 100644 index 0000000..ccc4571 Binary files /dev/null and 
b/result/B2745SZV.jpg differ diff --git a/result/B2770TRN.jpg b/result/B2770TRN.jpg new file mode 100644 index 0000000..5bc37de Binary files /dev/null and b/result/B2770TRN.jpg differ diff --git a/result/B2772KFI.jpg b/result/B2772KFI.jpg new file mode 100644 index 0000000..d4d10da Binary files /dev/null and b/result/B2772KFI.jpg differ diff --git a/result/B2785KF1.jpg b/result/B2785KF1.jpg new file mode 100644 index 0000000..4647294 Binary files /dev/null and b/result/B2785KF1.jpg differ diff --git a/result/B2811KKB.jpg b/result/B2811KKB.jpg new file mode 100644 index 0000000..44a7fb9 Binary files /dev/null and b/result/B2811KKB.jpg differ diff --git a/result/B2815SYR.jpg b/result/B2815SYR.jpg new file mode 100644 index 0000000..2639ee7 Binary files /dev/null and b/result/B2815SYR.jpg differ diff --git a/result/B2820KKL.jpg b/result/B2820KKL.jpg new file mode 100644 index 0000000..371ee78 Binary files /dev/null and b/result/B2820KKL.jpg differ diff --git a/result/B2833BYW.jpg b/result/B2833BYW.jpg new file mode 100644 index 0000000..35d709c Binary files /dev/null and b/result/B2833BYW.jpg differ diff --git a/result/B2863UOB.jpg b/result/B2863UOB.jpg new file mode 100644 index 0000000..5beba9f Binary files /dev/null and b/result/B2863UOB.jpg differ diff --git a/result/B2866.jpg b/result/B2866.jpg new file mode 100644 index 0000000..eeb928f Binary files /dev/null and b/result/B2866.jpg differ diff --git a/result/B2868BRE.jpg b/result/B2868BRE.jpg new file mode 100644 index 0000000..a7bb991 Binary files /dev/null and b/result/B2868BRE.jpg differ diff --git a/result/B308UPZ.jpg b/result/B308UPZ.jpg new file mode 100644 index 0000000..ef53225 Binary files /dev/null and b/result/B308UPZ.jpg differ diff --git a/result/B369SRZ.jpg b/result/B369SRZ.jpg new file mode 100644 index 0000000..7663ef2 Binary files /dev/null and b/result/B369SRZ.jpg differ diff --git a/result/B501CAM.jpg b/result/B501CAM.jpg new file mode 100644 index 0000000..d6b5c67 Binary files /dev/null and 
b/result/B501CAM.jpg differ diff --git a/result/B555UBU.jpg b/result/B555UBU.jpg new file mode 100644 index 0000000..f6e9ab4 Binary files /dev/null and b/result/B555UBU.jpg differ diff --git a/result/B678RKZ.jpg b/result/B678RKZ.jpg new file mode 100644 index 0000000..885c284 Binary files /dev/null and b/result/B678RKZ.jpg differ diff --git a/result/B762PD.jpg b/result/B762PD.jpg new file mode 100644 index 0000000..c27f1f8 Binary files /dev/null and b/result/B762PD.jpg differ diff --git a/result/B806NB.jpg b/result/B806NB.jpg new file mode 100644 index 0000000..f8a06f4 Binary files /dev/null and b/result/B806NB.jpg differ diff --git a/result/B811MZS.jpg b/result/B811MZS.jpg new file mode 100644 index 0000000..d579410 Binary files /dev/null and b/result/B811MZS.jpg differ diff --git a/result/B889YPL.jpg b/result/B889YPL.jpg new file mode 100644 index 0000000..3f0d0b2 Binary files /dev/null and b/result/B889YPL.jpg differ diff --git a/result/BZ648SF7.jpg b/result/BZ648SF7.jpg new file mode 100644 index 0000000..865eb96 Binary files /dev/null and b/result/BZ648SF7.jpg differ diff --git a/result/E2519.jpg b/result/E2519.jpg new file mode 100644 index 0000000..ba32d3f Binary files /dev/null and b/result/E2519.jpg differ diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000..bbb945b Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/apps/__init__.py b/src/apps/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/apps/char_detection.py b/src/apps/char_detection.py new file mode 100644 index 0000000..1bd48f0 --- /dev/null +++ b/src/apps/char_detection.py @@ -0,0 +1,123 @@ +''' +@Author : Ali Mustofa HALOTEC +@Module : Character Detection Faster RCNN +@Created on : 19 Jul 2022 +''' +#!/usr/bin/env python3 +# Path: src/apps/char_detection.py +import os +import cv2 +import numpy as np +from PIL import Image +from src.utils.utils import download_and_unzip_model +import torch +import torchvision +from torchvision 
class CharDetection:
    '''
    Character detector built on torchvision's Faster R-CNN (ResNet-50 FPN).

    The weights file is expected at ``{root_path}/{model_config["filename"]}``
    and is fetched with ``download_and_unzip_model`` when it is missing.
    '''

    def __init__(self, root_path:str, model_config:dict) -> None:
        '''
        Load model
        @params:
            - root_path:str -> root of path model
            - model_config:dict -> config of model {filename, classes, url, file_size}
        '''
        self.model_name = f'{root_path}/{model_config["filename"]}'
        self.classes = model_config['classes']
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        # __check_model needs the download settings; calling it without
        # arguments raised TypeError before any model could be loaded.
        self.__check_model(root_path, model_config)
        self.model = self.__load_model()

    def __check_model(self, root_path:str, model_config:dict) -> None:
        '''
        Download the weights file when it is not present on disk.
        @params:
            - root_path:str -> directory the weights are stored in
            - model_config:dict -> config of model {filename, url, file_size}
        '''
        if not os.path.isfile(self.model_name):
            download_and_unzip_model(
                root_dir = root_path,
                name = model_config['filename'],
                url = model_config['url'],
                file_size = model_config['file_size'],
                unzip = False
            )
        else: print('Load model')

    @staticmethod
    def __image_transform(image) -> torch.Tensor:
        '''Convert a PIL image into a tensor with ``transforms.ToTensor``.'''
        return transforms.Compose([transforms.ToTensor()])(image)

    def __load_model(self) -> torch.nn.Module:
        '''
        Build the network, load the weights and switch to eval mode.
        @return:
            - model: torch.nn.Module placed on ``self.device``
        '''
        model = self.__fasterrcnn_resnet50_fpn()
        # strict=False: keys missing from / unexpected in the checkpoint are ignored.
        state_dict = torch.load(self.model_name, map_location=self.device)
        model.load_state_dict(state_dict, strict=False)
        model.to(self.device)
        return model.eval()

    def __fasterrcnn_resnet50_fpn(self)-> torch.nn.Module:
        '''
        Create a Faster R-CNN whose box predictor has ``len(self.classes)+1`` outputs.
        '''
        # NOTE(review): pretrained=True fetches torchvision's pretrained weights
        # before the checkpoint is loaded on top — confirm this is still wanted.
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(self.classes)+1)
        return model

    @staticmethod
    def __filter_threshold(probs:dict, threshold:float) -> dict:
        '''
        Keep only the detections whose score is strictly above ``threshold``.
        The dict is updated in place and returned.
        '''
        keep = probs['scores'] > threshold
        for key in ('boxes', 'scores', 'labels'):
            probs[key] = probs[key][keep]
        return probs

    @staticmethod
    def __original_boxes(boxes:torch.Tensor, img_size:tuple,resized:int) -> torch.Tensor:
        '''
        Map boxes from the ``resized`` x ``resized`` image back to the original image.
        @params:
            - boxes: tensor of [x_min, y_min, x_max, y_max] rows
            - img_size: (height, width) of the original image
            - resized: edge length the image was resized to before detection
        @return:
            - tensor of rescaled boxes; shape (0, 4) when there are no boxes
        '''
        image_width, image_height = img_size[1], img_size[0]
        scale = boxes.new_tensor(
            [image_width, image_height, image_width, image_height]) / resized
        # reshape keeps an empty result two-dimensional so later slicing works.
        return boxes.reshape(-1, 4) * scale

    @staticmethod
    def __sort_by_boxes(probs:dict) -> dict:
        '''
        Order detections from left to right by ``x_min``.
        The dict is updated in place and returned.
        '''
        # argsort yields every index exactly once; looking positions up with
        # list.index() repeated the first match when two boxes shared an x_min.
        order = torch.argsort(probs['boxes'].reshape(-1, 4)[:, 0])
        for key in ('boxes', 'scores', 'labels'):
            probs[key] = probs[key][order]
        return probs

    def detect(self, image:np.array, size:int = None,
            boxes_ori:bool = False, threshold:float = 0.5, sorted:bool = True) -> dict:
        '''
        @params:
            - image: numpy array of image
            - size: int of image resize
            - boxes_ori: bool of original boxes
            - threshold: float of threshold
            - sorted: bool of sorted by boxes
        @return:
            probs: dict of probs -> {
                'boxes' : [x_min, y_min, x_max, y_max],
                'scores': [float],
                'labels': [int]
            }
        '''
        im_shape = (image.shape[0], image.shape[1])
        image = cv2.resize(image, (size,size)) if size else image
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # The model lives on self.device, so the input has to be moved there too.
        image = self.__image_transform(image).to(self.device)
        with torch.no_grad():
            probs = self.model([image])[0]
        probs = self.__filter_threshold(probs, threshold)
        if boxes_ori and size:
            probs['boxes'] = self.__original_boxes(probs['boxes'],im_shape, size)
        if sorted:
            probs = self.__sort_by_boxes(probs)
        return {k: v.cpu().numpy() for k, v in probs.items()}
100644 index 0000000..91c2a35 --- /dev/null +++ b/src/apps/char_recognition.py @@ -0,0 +1,143 @@ +''' +@Author : Ali Mustofa HALOTEC +@Module : Character Recognition Neural Network +@Created on : 20 Jul 2022 +''' +#!/usr/bin/env python3 +# Path: src/apps/char_recognition.py + +import os +import cv2 +import numpy as np +from PIL import Image +from src.utils.utils import download_and_unzip_model +import torch +import torch.nn as nn +from torchvision import transforms + +class _NeuralNetwork(nn.Module): + def __init__(self, num_classes): + super(_NeuralNetwork, self).__init__() + + self.conv1 = nn.Sequential( + nn.Conv2d(3, 32, 3, padding=1), + nn.ReLU(), + nn.BatchNorm2d(32), + nn.Conv2d(32, 32, 3, stride=2, padding=1), + nn.ReLU(), + nn.BatchNorm2d(32), + nn.MaxPool2d(2, 2), + nn.Dropout(0.25) + ) + + self.conv2 = nn.Sequential( + nn.Conv2d(32, 64, 3, padding=1), + nn.ReLU(), + nn.BatchNorm2d(64), + nn.Conv2d(64, 64, 3, stride=2, padding=1), + nn.ReLU(), + nn.BatchNorm2d(64), + nn.MaxPool2d(2, 2), + nn.Dropout(0.25) + ) + + self.conv3 = nn.Sequential( + nn.Conv2d(64, 128, 3, padding=1), + nn.ReLU(), + nn.BatchNorm2d(128), + nn.MaxPool2d(2, 2), + nn.Dropout(0.25) + ) + + self.fc = nn.Sequential( + nn.Linear(128, num_classes), + ) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + + x = x.view(x.size(0), -1) + return self.fc(x) + +class CharRecognition: + + def __init__(self, root_path:str, model_config:dict) -> None: + ''' + Load model + @params: + - model_name: str of model name + - classes: list of classes + ''' + self.model_name = f'{root_path}/{model_config["filename"]}' + self.classes = model_config['classes'] + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + self.__check_model(root_path, model_config) + self.model = self.__load_model() + + def __check_model(self, root_path:str, model_config:dict) -> None: + if not os.path.isfile(self.model_name): + download_and_unzip_model( + 
class CharRecognition:
    '''
    Single-character classifier wrapping ``_NeuralNetwork``.

    The weights file is expected at ``{root_path}/{model_config["filename"]}``
    and is fetched with ``download_and_unzip_model`` when it is missing.
    '''

    def __init__(self, root_path:str, model_config:dict) -> None:
        '''
        Load model
        @params:
            - root_path: str -> directory holding the weights file
            - model_config: dict -> config of model {filename, classes, url, file_size}
        '''
        self.model_name = f'{root_path}/{model_config["filename"]}'
        self.classes = model_config['classes']
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.__check_model(root_path, model_config)
        self.model = self.__load_model()

    def __check_model(self, root_path:str, model_config:dict) -> None:
        '''Download the weights file unless it already exists on disk.'''
        if os.path.isfile(self.model_name):
            print('Load model')
            return
        download_and_unzip_model(
            root_dir = root_path,
            name = model_config['filename'],
            url = model_config['url'],
            file_size = model_config['file_size'],
            unzip = False
        )

    def __load_model(self) -> nn.Module:
        '''
        Load model from file
        @return:
            - model: nn.Module in eval mode on ``self.device``
        '''
        network = _NeuralNetwork(len(self.classes))
        weights = torch.load(self.model_name, map_location=self.device)
        network.load_state_dict(weights)
        network.to(self.device)
        return network.eval()

    @staticmethod
    def __image_transform(image) -> torch.Tensor:
        '''
        Resize/crop to 31x31, convert to a 3-channel grayscale tensor and
        normalise with mean 0.5 and std 0.5.
        '''
        pipeline = transforms.Compose([
            transforms.Resize(size=(31,31)),
            transforms.CenterCrop(size=31),
            transforms.ToTensor(),
            transforms.Grayscale(3),
            transforms.Normalize(mean=(0.5,), std=(0.5,))
        ])
        return pipeline(image)

    def recognition(self, image:np.array) -> dict:
        '''
        Recognize character from image
        @params:
            - image: np.array (converted from BGR to RGB before classification)
        @return:
            - result: dict -> {'text': recognised class, 'conf': probability rounded to 2 places}
        '''
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        batch = self.__image_transform(Image.fromarray(rgb)).view(1, 3, 31, 31)
        if self.device == torch.device('cuda'):
            batch = batch.cuda()

        with torch.no_grad():
            logits = self.model(batch)

        # exp(log_softmax) yields the class probabilities.
        probabilities = torch.exp(nn.functional.log_softmax(logits, dim=1))
        prob, top_class = torch.topk(probabilities, k=1, dim=1)
        class_index = top_class.cpu().numpy()[0][0]
        confidence = round((prob.cpu().numpy()[0][0]), 2)
        return {
            'text': self.classes[class_index],
            'conf': confidence
        }
class Ocr:
    '''
    Two-stage OCR: locate characters with ``CharDetection`` and read each
    crop with ``CharRecognition``.
    '''

    def __init__(self, detection:str = None, recog:str = None) -> None:
        '''
        @params:
            - detection: str -> path of the detection weights file (None disables detection)
            - recog: str -> path of the recognition weights file (None disables recognition)
        '''
        self.detection = detection
        self.recog = recog
        if detection:
            from char_detection import CharDetection
            # CharDetection/CharRecognition take (root_path, model_config);
            # the old model_name=/classes= keywords raised TypeError.
            self.detection_model = CharDetection(
                *self.__model_args(detection, 'char_detection'))
        if recog:
            from char_recognition import CharRecognition
            self.recog_model = CharRecognition(
                *self.__model_args(recog, 'char_recognition'))

    @staticmethod
    def __model_args(model_path:str, config_key:str) -> tuple:
        '''
        Split a weights path into the (root_path, model_config) pair the model
        classes expect, taking classes/url/file_size from ``MODELS[config_key]``.
        '''
        import os
        root_path, filename = os.path.split(model_path)
        model_config = dict(MODELS[config_key], filename=filename)
        return root_path or '.', model_config

    def char_detection(self, image:np.array, image_size:int = 244,
            threshold:float = 0.5, boxes_ori:bool = True, det_sorted:bool = True) -> dict:
        '''
        Detect character from image
        @params:
            - image: np.array -> image to be detected
            - image_size: int -> size of image to be detected
            - threshold: float -> threshold for detection
            - boxes_ori: bool -> if True, return boxes in original image
            - det_sorted: bool -> if True, return boxes in sorted order
        @return:
            - result: the dict returned by the detection model's ``detect``
        '''
        # assert error if model is not loaded
        assert self.detection, 'Model is not loaded'

        result_det = self.detection_model.detect(image, image_size,
                                    boxes_ori, threshold, sorted=det_sorted)
        return result_det

    def char_recognition(self, image: np.array) -> dict:
        '''
        Read single character from image
        @params:
            - image: np.array -> image to be read
        @return:
            - result: {'text': str, 'conf': float}
        '''
        # assert error if model is not loaded
        assert self.recog, 'Model is not loaded'

        return self.recog_model.recognition(image)

    def __calculate_confidence(self, result:list) -> float:
        '''Mean of the per-character 'conf' values, 0.0 when nothing was read.'''
        if not result:
            return 0.0
        return round(sum([i['conf'] for i in result])/len(result),2)

    def __marger_text(self, result:list) -> str:
        '''Concatenate the per-character 'text' values in order.'''
        return ''.join([i['text'] for i in result])

    def visualize_result(self, image:np.array, results:list) -> np.array:
        '''
        Visualize result of OCR
        @params:
            - image: np.array -> image to be draw (drawn on in place)
            - results: list -> result of OCR(output type advanced)
        @return:
            - image: np.array -> image with result
        '''
        for box in results:
            # Plain ints: the coordinates arrive as numpy integers.
            x_min, y_min, x_max, y_max = (int(v) for v in box['box'])
            # Draw boxes
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 1)
            # Draw text
            text = box['text']
            cv2.putText(image, text, (x_min, y_min), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
        return image

    def ocr(self, image:np.array, det_size:int = 244, boxes_ori:bool = True,
            det_threshold:float=0.5, det_sorted:bool=True, output_type:str='normal'):
        '''
        Read text from image using Text Detection and Recognition
        @params:
            - image: np.array -> image to be read
            - det_size: int -> size of image to be detected
            - boxes_ori: bool -> if True, return boxes in original image
            - det_threshold: float -> threshold for detection
            - det_sorted: bool -> if True, return boxes in sorted order
            - output_type: str -> 'normal' or 'advanced'
        @return:
            - result: result of detection and recognition
                - normal    : {'confidence': float, 'text': str}
                - advanced  : [{'text': str, 'conf': float, 'box': np.array}]
        '''
        # assert error if output type not in ['normal', 'advanced']
        assert output_type in ['normal', 'advanced'], 'Output type is not valid'
        # Char detection
        res_detection = self.char_detection(image=image, image_size=det_size,
            threshold=det_threshold, boxes_ori=boxes_ori, det_sorted=det_sorted)
        boxes = res_detection['boxes'].astype(int)

        # Char recognition
        height, width = image.shape[0], image.shape[1]
        result_recognition = list()
        for box in boxes:
            # Clamp to the image: a negative start would wrap around when slicing.
            x_min, y_min = max(int(box[0]), 0), max(int(box[1]), 0)
            x_max, y_max = min(int(box[2]), width), min(int(box[3]), height)
            if x_max <= x_min or y_max <= y_min:
                # Nothing to read in an empty crop.
                continue
            image_crop = image[y_min:y_max, x_min:x_max]
            res_recognition = self.char_recognition(image_crop)
            if output_type == 'normal':
                result_recognition.append(res_recognition)
            elif output_type == 'advanced':
                result_recognition.append({
                    'text': res_recognition['text'],
                    'conf': res_recognition['conf'],
                    'box': box})

        # Output type
        if output_type == 'normal':
            confidence = self.__calculate_confidence(result_recognition)
            text = self.__marger_text(result_recognition)
            result = {'confidence': confidence, 'text': text}
        elif output_type == 'advanced':
            result = result_recognition
        return result
class Net(nn.Module):
    '''
    Convolutional classifier with three stages and a linear head that has
    ``num_classes`` outputs (``num_classes`` is imported from ``configs.ocr``).
    '''

    def __init__(self):
        super().__init__()

        def conv_unit(in_channels, out_channels, **conv_kwargs):
            # 3x3 convolution (padding 1) followed by ReLU and batch-norm.
            return [
                nn.Conv2d(in_channels, out_channels, 3, padding=1, **conv_kwargs),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
            ]

        def pool_unit():
            # 2x2 max-pooling followed by dropout.
            return [nn.MaxPool2d(2, 2), nn.Dropout(0.25)]

        # Attribute names and layer order fix the state_dict keys; keep them.
        self.conv1 = nn.Sequential(
            *conv_unit(3, 32), *conv_unit(32, 32, stride=2), *pool_unit())
        self.conv2 = nn.Sequential(
            *conv_unit(32, 64), *conv_unit(64, 64, stride=2), *pool_unit())
        self.conv3 = nn.Sequential(*conv_unit(64, 128), *pool_unit())

        self.fc = nn.Sequential(
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        '''Apply the three stages, flatten per sample and classify.'''
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        return self.fc(x.view(x.size(0), -1))
+ file_size(int): The size of model. + unzip(bool): Unzip the model or not. + ''' + Path(root_dir).mkdir(parents=True, exist_ok=True) + + # check if model is already or not + print(f'Downloading {root_dir.split("/")[-1]} model, please wait.') + response = requests.get(url, stream=True) + + progress = tqdm(response.iter_content(1024), + f'Downloading model', + total=file_size, unit='B', + unit_scale=True, unit_divisor=1024) + save_dir = f'{root_dir}/{name}' + with open(save_dir, 'wb') as f: + for data in progress: + f.write(data) + progress.update(len(data)) + print(f'Done downloading {root_dir.split("/")[-1]} model.') + + # unzip model + if unzip: + with ZipFile(save_dir, 'r') as zip_obj: + zip_obj.extractall(root_dir) + print(f'Done unzip {root_dir.split("/")[-1]} model.') + os.remove(save_dir) + +def encode_image2string(image): + image_list = cv2.imencode('.jpg', image)[1] + image_bytes = image_list.tobytes() + image_encoded = base64.b64encode(image_bytes) + return image_encoded + +def decode_string2image(image_encoded): + jpg_original = base64.b64decode(image_encoded) + jpg_as_np = np.frombuffer(jpg_original, dtype=np.uint8) + image = cv2.imdecode(jpg_as_np, flags=1) + return image + +def resize_image(image, size_percent): + ''' + Resize an image so that its longest edge equals to the given size. + Args: + image(cv2.Image): The input image. + size_percent(int): The size of longest edge. + Returns: + image(cv2.Image): The output image. 
+ ''' + width = int(image.shape[1] * size_percent / 100) + height = int(image.shape[0] * size_percent / 100) + dim = (width, height) + + # resize image + resized = cv2.resize(image, dim, interpolation = cv2.INTER_AREA) + return resized \ No newline at end of file diff --git a/text_detection.ipynb b/text_detection.ipynb new file mode 100644 index 0000000..aeef9e5 --- /dev/null +++ b/text_detection.ipynb @@ -0,0 +1,592 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/homebrew/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import cv2\n", + "import numpy as np\n", + "from PIL import Image\n", + "from matplotlib import pyplot as plt\n", + "import matplotlib.patches as patches\n", + "\n", + "\n", + "import torch\n", + "import torchvision\n", + "from torchvision import transforms\n", + "from torchvision.models.detection.faster_rcnn import FastRCNNPredictor\n", + "\n", + "import albumentations as A\n", + "from albumentations.pytorch.transforms import ToTensorV2\n", + "\n", + "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class CharDetection:\n", + " def __init__(self, model_name:str, classes:list) -> None:\n", + " self.model_name = model_name\n", + " self.classes = classes\n", + " self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", + " self.model = self.__load_model()\n", + "\n", + " @staticmethod\n", + " def __image_transform(image) -> torch.Tensor:\n", + " return transforms.Compose([transforms.ToTensor()])(image)\n", + "\n", + " def 
__load_model(self) -> torch.nn.Module:\n", + " model = self.__fasterrcnn_resnet50_fpn()\n", + " model.load_state_dict(torch.load(self.model_name, map_location=self.device), False)\n", + " model.to(self.device)\n", + " return model.eval()\n", + "\n", + " def __fasterrcnn_resnet50_fpn(self)-> torch.nn.Module:\n", + " model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)\n", + " in_features = model.roi_heads.box_predictor.cls_score.in_features\n", + " model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(self.classes)+1)\n", + " return model\n", + "\n", + " @staticmethod\n", + " def __filter_threshold(probs:dict, threshold:float) -> dict:\n", + " num_filtered = (probs['scores']>threshold).float()\n", + " keep = (num_filtered == torch.tensor(1)).nonzero().flatten()\n", + " final_probs = probs\n", + " final_probs['boxes'] = final_probs['boxes'][keep]\n", + " final_probs['scores'] = final_probs['scores'][keep]\n", + " final_probs['labels'] = final_probs['labels'][keep]\n", + " return final_probs\n", + "\n", + " @staticmethod\n", + " def __original_boxes(boxes:torch.Tensor, img_size:tuple,resized:int) -> torch.Tensor:\n", + " image_width, image_height = img_size[1], img_size[0]\n", + " boxes = torch.tensor([[\n", + " (x_min/resized)*image_width, (y_min/resized)*image_height, \\\n", + " (x_max/resized)*image_width, (y_max/resized)*image_height] \\\n", + " for (x_min, y_min, x_max, y_max) in boxes.cpu().numpy()])\n", + " return boxes\n", + " \n", + " @staticmethod\n", + " def __sort_by_boxes(probs:dict) -> dict:\n", + " x_min_list = [i[0] for i in probs['boxes']]\n", + " idx = [x_min_list.index(x) for x in sorted(x_min_list)]\n", + " probs['boxes'] = probs['boxes'][idx]\n", + " probs['scores'] = probs['scores'][idx]\n", + " probs['labels'] = probs['labels'][idx]\n", + " return probs \n", + "\n", + " def detect(self, image:np.array, size:int = None, \n", + " boxes_ori:bool = False, threshold:float = 0.5, sorted:bool = True) -> dict:\n", + 
" '''\n", + " Args:\n", + " - images : Numpy.Array ->\n", + " '''\n", + " im_shape = (image.shape[1], image.shape[0])\n", + " image = cv2.resize(image, (size,size)) if size else image\n", + " image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n", + " image = self.__image_transform(image)\n", + " with torch.no_grad():\n", + " probs = self.model([image])[0]\n", + " probs = self.__filter_threshold(probs, threshold)\n", + " if boxes_ori and size:\n", + " probs['boxes'] = self.__original_boxes(probs['boxes'],im_shape, size)\n", + " if sorted:\n", + " probs = self.__sort_by_boxes(probs)\n", + " return {k: v.cpu().numpy() for k, v in probs.items()}\n", + "\n", + "\n", + "\n", + "\n", + "char_detection = CharDetection('./models/text_detection.ali', ['text'])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'boxes': array([[11.86109069, 50.18588 , 17.78047777, 76.3388877 ],\n", + " [22.61613852, 50.53432627, 27.35313309, 77.12569721],\n", + " [23.92049417, 50.51673333, 35.74157014, 78.21390984],\n", + " [26.93010896, 51.20184805, 32.03692624, 78.96562426],\n", + " [31.03131979, 50.9152495 , 35.84480542, 79.37770781],\n", + " [32.11097877, 53.13152939, 34.70335989, 76.93529992],\n", + " [34.21225804, 51.55257194, 37.35505132, 78.26963218],\n", + " [35.52049355, 50.2777568 , 40.2434603 , 77.74667902],\n", + " [49.1127779 , 50.9270992 , 54.50939891, 79.81615429],\n", + " [54.97524605, 54.84159779, 59.88662901, 79.10287951]]),\n", + " 'labels': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),\n", + " 'scores': array([0.9523825 , 0.9936494 , 0.22991937, 0.9894872 , 0.9672788 ,\n", + " 0.05240851, 0.06529855, 0.98821217, 0.9909609 , 0.7661257 ],\n", + " dtype=float32)}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "image = cv2.imread('./images/2.jpg')\n", + "char_result = char_detection.detect(image, size=244, boxes_ori=True, 
threshold=0.01)\n", + "char_result" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[11.86109069, 50.18588 , 17.78047777, 76.3388877 ],\n", + " [22.61613852, 50.53432627, 27.35313309, 77.12569721],\n", + " [23.92049417, 50.51673333, 35.74157014, 78.21390984],\n", + " [26.93010896, 51.20184805, 32.03692624, 78.96562426],\n", + " [31.03131979, 50.9152495 , 35.84480542, 79.37770781],\n", + " [32.11097877, 53.13152939, 34.70335989, 76.93529992],\n", + " [34.21225804, 51.55257194, 37.35505132, 78.26963218],\n", + " [35.52049355, 50.2777568 , 40.2434603 , 77.74667902],\n", + " [49.1127779 , 50.9270992 , 54.50939891, 79.81615429],\n", + " [54.97524605, 54.84159779, 59.88662901, 79.10287951]])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_min_list = [i[0] for i in char_result['boxes']]\n", + "idx = [x_min_list.index(x) for x in sorted(x_min_list)]\n", + "print(idx)\n", + "char_result['boxes'][idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "FasterRCNN(\n", + " (transform): GeneralizedRCNNTransform(\n", + " Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", + " Resize(min_size=(800,), max_size=1333, mode='bilinear')\n", + " )\n", + " (backbone): BackboneWithFPN(\n", + " (body): IntermediateLayerGetter(\n", + " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", + " (bn1): FrozenBatchNorm2d(64, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", + " (layer1): Sequential(\n", + " (0): Bottleneck(\n", + " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " 
(bn1): FrozenBatchNorm2d(64, eps=0.0)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(64, eps=0.0)\n", + " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(256, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): FrozenBatchNorm2d(256, eps=0.0)\n", + " )\n", + " )\n", + " (1): Bottleneck(\n", + " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(64, eps=0.0)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(64, eps=0.0)\n", + " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(256, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " (2): Bottleneck(\n", + " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(64, eps=0.0)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(64, eps=0.0)\n", + " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(256, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (layer2): Sequential(\n", + " (0): Bottleneck(\n", + " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(128, eps=0.0)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(128, eps=0.0)\n", + " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(512, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " (downsample): Sequential(\n", + " (0): 
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): FrozenBatchNorm2d(512, eps=0.0)\n", + " )\n", + " )\n", + " (1): Bottleneck(\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(128, eps=0.0)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(128, eps=0.0)\n", + " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(512, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " (2): Bottleneck(\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(128, eps=0.0)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(128, eps=0.0)\n", + " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(512, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " (3): Bottleneck(\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(128, eps=0.0)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(128, eps=0.0)\n", + " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(512, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (layer3): Sequential(\n", + " (0): Bottleneck(\n", + " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(1024, 
eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): FrozenBatchNorm2d(1024, eps=0.0)\n", + " )\n", + " )\n", + " (1): Bottleneck(\n", + " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " (2): Bottleneck(\n", + " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " (3): Bottleneck(\n", + " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " (4): Bottleneck(\n", + " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), 
stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " (5): Bottleneck(\n", + " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(256, eps=0.0)\n", + " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (layer4): Sequential(\n", + " (0): Bottleneck(\n", + " (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(512, eps=0.0)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(512, eps=0.0)\n", + " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(2048, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): FrozenBatchNorm2d(2048, eps=0.0)\n", + " )\n", + " )\n", + " (1): Bottleneck(\n", + " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(512, eps=0.0)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): FrozenBatchNorm2d(512, eps=0.0)\n", + " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(2048, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " (2): Bottleneck(\n", + " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn1): FrozenBatchNorm2d(512, eps=0.0)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 
bias=False)\n", + " (bn2): FrozenBatchNorm2d(512, eps=0.0)\n", + " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (bn3): FrozenBatchNorm2d(2048, eps=0.0)\n", + " (relu): ReLU(inplace=True)\n", + " )\n", + " )\n", + " )\n", + " (fpn): FeaturePyramidNetwork(\n", + " (inner_blocks): ModuleList(\n", + " (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (2): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (3): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (layer_blocks): ModuleList(\n", + " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (extra_blocks): LastLevelMaxPool()\n", + " )\n", + " )\n", + " (rpn): RegionProposalNetwork(\n", + " (anchor_generator): AnchorGenerator()\n", + " (head): RPNHead(\n", + " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))\n", + " (bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " (roi_heads): RoIHeads(\n", + " (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)\n", + " (box_head): TwoMLPHead(\n", + " (fc6): Linear(in_features=12544, out_features=1024, bias=True)\n", + " (fc7): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (box_predictor): FastRCNNPredictor(\n", + " (cls_score): Linear(in_features=1024, out_features=2, bias=True)\n", + " (bbox_pred): Linear(in_features=1024, out_features=8, bias=True)\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)\n", + "# get number of input features for the classifier\n", + "in_features = model.roi_heads.box_predictor.cls_score.in_features\n", + "# replace the pre-trained head with a new one\n", + "model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)\n", + "model.load_state_dict(torch.load('./models/text_detection.ali', map_location=device), False)\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "transforms_img = transforms.Compose([\n", + " # transforms.Resize(size=(244,244)),\n", + " transforms.ToTensor()\n", + "\n", + " ])\n", + "\n", + "def filter_probs(probs, iou_thresh=0.3):\n", + " num_filtered = (probs['scores']>iou_thresh).float()\n", + " keep = (num_filtered == torch.tensor(1)).nonzero().flatten()\n", + " final_probs = probs\n", + " final_probs['boxes'] = final_probs['boxes'][keep]\n", + " final_probs['scores'] = final_probs['scores'][keep]\n", + " final_probs['labels'] = final_probs['labels'][keep]\n", + " return final_probs" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def original_boxes(boxes, img_size,resized):\n", + " image_width, image_height = img_size[1], img_size[0]\n", + " boxes = torch.tensor([[\n", + " (x_min/resized)*image_width, (y_min/resized)*image_height, \\\n", + " (x_max/resized)*image_width, (y_max/resized)*image_height] \\\n", + " for (x_min, y_min, x_max, y_max) in boxes.cpu().numpy()])\n", + " return boxes" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'boxes': tensor([[ 40.8441, 27.9818, 49.3989, 42.7060],\n", + " [ 88.6962, 28.1993, 98.4423, 44.1957],\n", + " [ 48.6350, 28.3514, 57.8577, 43.7248],\n", + " [ 64.1490, 27.8397, 72.6785, 43.0498],\n", + " [ 
56.0416, 28.1927, 64.7346, 43.9529],\n", + " [ 21.4208, 27.7889, 32.1110, 42.2703],\n", + " [ 99.2837, 30.3668, 108.1535, 43.8008]], dtype=torch.float64),\n", + " 'labels': tensor([1, 1, 1, 1, 1, 1, 1]),\n", + " 'scores': tensor([0.9936, 0.9910, 0.9895, 0.9882, 0.9673, 0.9524, 0.7661])}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "resize = 244\n", + "output_ori = True\n", + "image_ori = cv2.imread('./images/2.jpg')\n", + "image = cv2.resize(image_ori, (resize,resize)) if resize else image_ori\n", + "image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n", + "image_t = transforms_img(image)\n", + "model.eval()\n", + "with torch.no_grad():\n", + " probs = model([image_t])[0]\n", + "probs = filter_probs(probs, iou_thresh=0.4)\n", + "if resize and output_ori:\n", + " probs['boxes'] = original_boxes(probs['boxes'], image_ori.shape, resize)\n", + "probs" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "40.84407105993051 27.981817026607324 49.39894185300734 42.70596457309411\n", + "88.69621082993805 28.199302860947903 98.44234729204021 44.19572179825579\n", + "48.634972900640776 28.3514365211862 57.85773246014705 43.72476715338034\n", + "64.14895104580238 27.839749633288775 72.67848799658604 43.0498140053671\n", + "56.04163723304624 28.192741456578986 64.73464859509077 43.9529456466925\n", + "21.420775726193284 27.788875704906026 32.11101209921915 42.27029318887679\n", + "99.28365332181338 30.366835140791096 108.15346433295578 43.8007679923636\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "def plot_img_bbox(img, target):\n", + " # plot the image and bboxes\n", + " # Bounding boxes are defined as follows: x-min y-min width height\n", + " fig, a = plt.subplots(1,1)\n", + " fig.set_size_inches(5,5)\n", + " a.imshow(img)\n", + " for box in (target['boxes'].cpu().numpy()):\n", + " print( box[0], box[1], box[2], box[3])\n", + " x, y, width, height = box[0], box[1], box[2]-box[0], box[3]-box[1]\n", + " rect = patches.Rectangle((x, y),\n", + " width, height,\n", + " linewidth = 2,\n", + " edgecolor = 'r',\n", + " facecolor = 'none')\n", + "\n", + " # Draw the bounding box on top of the image\n", + " a.add_patch(rect)\n", + " \n", + " plt.show()\n", + "plot_img_bbox(image_ori, probs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.13 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/text_recognition.ipynb b/text_recognition.ipynb new file mode 100644 index 0000000..e9fd51f --- /dev/null +++ b/text_recognition.ipynb @@ -0,0 +1,282 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/homebrew/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import os\n", + "import cv2\n", + "import string\n", + "import numpy as np\n", + "from PIL import Image\n", + "from glob import glob\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import torch, torchvision\n", + "from torchvision import transforms\n", + "\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "LABEL = string.digits+string.ascii_uppercase\n", + "label_dict = {idx : label for idx, label in enumerate(LABEL)}\n", + "num_classes = len(label_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Net(\n", + " (conv1): Sequential(\n", + " (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (7): Dropout(p=0.25, inplace=False)\n", + " )\n", + " (conv2): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (4): ReLU()\n", + " (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (6): 
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (7): Dropout(p=0.25, inplace=False)\n", + " )\n", + " (conv3): Sequential(\n", + " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (4): Dropout(p=0.25, inplace=False)\n", + " )\n", + " (fc): Sequential(\n", + " (0): Linear(in_features=128, out_features=36, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + " \n", + " self.conv1 = nn.Sequential(\n", + " nn.Conv2d(3, 32, 3, padding=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(32),\n", + " nn.Conv2d(32, 32, 3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(32),\n", + " nn.MaxPool2d(2, 2),\n", + " nn.Dropout(0.25)\n", + " )\n", + " \n", + " self.conv2 = nn.Sequential(\n", + " nn.Conv2d(32, 64, 3, padding=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(64),\n", + " nn.Conv2d(64, 64, 3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(64),\n", + " nn.MaxPool2d(2, 2),\n", + " nn.Dropout(0.25)\n", + " )\n", + " \n", + " self.conv3 = nn.Sequential(\n", + " nn.Conv2d(64, 128, 3, padding=1),\n", + " nn.ReLU(),\n", + " nn.BatchNorm2d(128),\n", + " nn.MaxPool2d(2, 2),\n", + " nn.Dropout(0.25)\n", + " )\n", + " \n", + " self.fc = nn.Sequential(\n", + " nn.Linear(128, num_classes),\n", + " )\n", + " \n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = self.conv2(x)\n", + " x = self.conv3(x)\n", + " \n", + " x = x.view(x.size(0), -1)\n", + " return self.fc(x)\n", + "\n", + "transforms_img = transforms.Compose([\n", + " transforms.Resize(size=(31,31)),\n", + " transforms.CenterCrop(size=31),\n", 
+ " transforms.ToTensor(),\n", + " transforms.Grayscale(3),\n", + " transforms.Normalize(mean=(0.5,), std=(0.5,))\n", + "\n", + " ])\n", + "\n", + "model = Net()\n", + "model.load_state_dict(torch.load('./models/text_recognition.ali', map_location=device))\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[1.]]) tensor([[1]])\n", + "1\n", + "1 1.0 ./images/1_10043.jpg\n", + "tensor([[1.]]) tensor([[1]])\n", + "1\n", + "1 1.0 ./images/1_10060.jpg\n", + "tensor([[1.]]) tensor([[1]])\n", + "1\n", + "1 1.0 ./images/1_10059.jpg\n", + "tensor([[1.]]) tensor([[1]])\n", + "1\n", + "1 1.0 ./images/1_10104.jpg\n", + "tensor([[0.7827]]) tensor([[5]])\n", + "5\n", + "5 0.78 ./images/12022041104595530.jpg\n", + "tensor([[1.]]) tensor([[1]])\n", + "1\n", + "1 1.0 ./images/1_10029.jpg\n", + "tensor([[0.6902]]) tensor([[3]])\n", + "3\n", + "3 0.69 ./images/1.jpg\n" + ] + } + ], + "source": [ + "for i in glob('./images/1*.jpg'):\n", + " image = cv2.imread(i)\n", + " image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n", + " image = transforms_img(image)\n", + " image = image.view(1, 3, 31, 31).cuda() if torch.cuda.is_available() else image.view(1, 3, 31, 31)\n", + "\n", + " with torch.no_grad():\n", + " model.eval()\n", + " output = model(image)\n", + " \n", + " output = torch.nn.functional.log_softmax(output, dim=1)\n", + " output = torch.exp(output)\n", + " prob, top_class = torch.topk(output, k=1, dim=1) \n", + " print(prob, top_class)\n", + " res_label = label_dict[top_class.cpu().numpy()[0][0]]\n", + " print(LABEL[top_class.cpu().numpy()[0][0]])\n", + " res_prob = round((prob.cpu().numpy()[0][0]), 2)\n", + " print(res_label, res_prob, i)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 3, 31, 31])" + ] + }, + "execution_count": 
7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "image.size()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAB8AAAAfCAIAAACQzIFuAAAChklEQVR4nJ1WUZbjIAwzNpDe/xY94QQCZD/UqA5JZt5bf7SUgizLBhPe77eIiMgYQ0T2w+Jhqoq/VBU/sbi1Ng7zCH7wWQ0IjGFmRnRO0tkYI8boHQDRD7A+Ap2bMcBnCEFVVTWEgGUhhH3fgdJ7771v24ZQWmv7vouzaGYAMrOcMwZADyEgIBhJkCNAa621VrjHZ2ut9y4iMaUE9Jwz0GHYOXHxEsEH/WEAdL8sikg6DMTNDOHDgZmZGTThTqhE6VQ1pYQ1sC86NsMzgiqllFJ+fn6QQK9PCAEMWDzwF0IAjy86BYXRc2utlLKua2tNnbE0QcvnwFfLSRlwIXegI12gNmXVF6uH885O6KpqZiy1J0bXk8VM0tkYY1Ym3JkX0TvD5BTQVK8iElmnCNarzGRyNSoE1DCJVMGfR/9wxxdoLsuSUuKp27ZtqjDsZ70CFIhIFT190uPvh6v5eLGNCfC4XAAlvug+3bszP48cTHBPx9hb5FJKyaTv++65/A/6L8r4E8SywTbqJmfp/ka/XXdbr3/Sv+d+xb06YC1euRPhiz6Fed3jZzxxX+Yz999Du3VA1n6SxYqEfc4qT9OTAhDhyRMHvfdrJuLtBnmuhKt0vqhmdFywtdbpOvOXOK5MbEDgbGG8jnBEeu/tsC86ttVaU0posGaWUlqWxcxwy49zW8g5L8sSY3y9XujDIoKusK5rrVVYM+I6zvSMaa3lnD00O990iV4Leq53X2eqmnOeksG3DZsGoufLCYGelPGdE2qKa0NT772qPI7XEgcfdAhE7vijlIKn0nThkDvmn9Cpz4yuqrVWvCamKvSXmuc+zi9Wr/6pZkgT6MSadCc6+hGxphSKyD+iMq/r9imv4QAAAABJRU5ErkJggg==", + "text/plain": [ + "" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "image = cv2.imread('/Users/alimustofa/Halotec/Source Code/research/ocr/from_scratch/images/C_13549.jpg')\n", + "image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n", + "image = transforms_img(image)\n", + "image = transforms.ToPILImage()(image)\n", + "image" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.13 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4, + "vscode": 
{ + "interpreter": { + "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}