From 4ff5ce282d40a9844ed7450e2a19e7c333ff3d9d Mon Sep 17 00:00:00 2001 From: Olivier Dulcy <106678676+odulcy-mindee@users.noreply.github.com> Date: Tue, 13 Feb 2024 11:15:22 +0100 Subject: [PATCH] feat: :sparkles: PT db_resnet50 checkpoint (#1465) --- .../differentiable_binarization/pytorch.py | 40 ++++++++++--------- tests/pytorch/test_models_zoo_pt.py | 4 +- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 82ced7da9..17686bb28 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -27,7 +27,7 @@ "input_shape": (3, 1024, 1024), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.3.1/db_resnet50-ac60cadc.pt&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-79bd7d70.pt&src=0", }, "db_resnet34": { "input_shape": (3, 1024, 1024), @@ -57,24 +57,28 @@ def __init__( conv_layer = DeformConv2d if deform_conv else nn.Conv2d - self.in_branches = nn.ModuleList([ - nn.Sequential( - conv_layer(chans, out_channels, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ) - for idx, chans in enumerate(in_channels) - ]) + self.in_branches = nn.ModuleList( + [ + nn.Sequential( + conv_layer(chans, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + for idx, chans in enumerate(in_channels) + ] + ) self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) - self.out_branches = nn.ModuleList([ - nn.Sequential( - conv_layer(out_channels, out_chans, 3, padding=1, bias=False), - nn.BatchNorm2d(out_chans), - nn.ReLU(inplace=True), - nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True), - ) - for idx, chans in enumerate(in_channels) - ]) + self.out_branches = nn.ModuleList( + [ + nn.Sequential( + conv_layer(out_channels, out_chans, 3, padding=1, bias=False), + nn.BatchNorm2d(out_chans), + nn.ReLU(inplace=True), + nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True), + ) + for idx, chans in enumerate(in_channels) + ] + ) def forward(self, x: List[torch.Tensor]) -> torch.Tensor: if len(x) != len(self.out_branches): diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index 3c6267ab7..5bcd10ee6 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -222,9 +222,9 @@ def test_trained_kie_predictor(mock_payslip): geometry_mr = np.array([[0.1083984375, 0.0634765625], [0.1494140625, 0.0859375]]) assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][0].geometry), geometry_mr, rtol=0.05) - assert out.pages[0].predictions[CLASS_NAME][6].value == "revised" + assert out.pages[0].predictions[CLASS_NAME][4].value == "revised" geometry_revised = np.array([[0.7548828125, 0.126953125], [0.8388671875, 0.1484375]]) - assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][6].geometry), geometry_revised, rtol=0.05) + assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][4].geometry), geometry_revised, rtol=0.05) det_predictor = detection_predictor( "db_resnet50",