From fb85c082073f560b0848ad28fe291c710a71c17f Mon Sep 17 00:00:00 2001
From: Olivier Dulcy
Date: Mon, 8 Jan 2024 17:41:38 +0100
Subject: [PATCH] DB loss fix from Felix

---
 .../differentiable_binarization/base.py       | 38 ++++----
 .../differentiable_binarization/pytorch.py    | 97 ++++++++++---------
 .../differentiable_binarization/tensorflow.py | 84 ++++++++--------
 3 files changed, 116 insertions(+), 103 deletions(-)

diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py
index acb0bb314..0d261d299 100644
--- a/doctr/models/detection/differentiable_binarization/base.py
+++ b/doctr/models/detection/differentiable_binarization/base.py
@@ -201,9 +201,10 @@ def compute_distance(
         square_dist_2 = np.square(xs - b[0]) + np.square(ys - b[1])
         square_dist = np.square(a[0] - b[0]) + np.square(a[1] - b[1])
         cosin = (square_dist - square_dist_1 - square_dist_2) / (2 * np.sqrt(square_dist_1 * square_dist_2) + eps)
+        cosin = np.clip(cosin, -1.0, 1.0)
         square_sin = 1 - np.square(cosin)
         square_sin = np.nan_to_num(square_sin)
-        result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist)
+        result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist + eps)
         result[cosin < 0] = np.sqrt(np.fmin(square_dist_1, square_dist_2))[cosin < 0]
         return result
 
@@ -265,7 +266,10 @@ def draw_thresh_map(
 
         # Fill the canvas with the distances computed inside the valid padded polygon
         canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
-            1 - distance_map[ymin_valid - ymin : ymax_valid - ymin + 1, xmin_valid - xmin : xmax_valid - xmin + 1],
+            1
+            - distance_map[
+                ymin_valid - ymin : ymax_valid - ymax + height, xmin_valid - xmin : xmax_valid - xmax + width
+            ],
             canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
         )
 
@@ -274,7 +278,7 @@ def build_target(
         self,
         target: List[Dict[str, np.ndarray]],
-        output_shape: Tuple[int, int, int, int],
+        output_shape: Tuple[int, int, int],
         channels_last: bool = True,
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
@@ -284,23 +288,24 @@ def build_target(
 
         input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32
 
+        h: int
+        w: int
         if channels_last:
-            h, w = output_shape[1:-1]
-            target_shape = (output_shape[0], output_shape[-1], h, w)  # (Batch_size, num_classes, h, w)
+            h, w, num_classes = output_shape
         else:
-            h, w = output_shape[-2:]
-            target_shape = output_shape  # (Batch_size, num_classes, h, w)
+            num_classes, h, w = output_shape
+        target_shape = (len(target), num_classes, h, w)
+
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
         seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
         thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32)
-        thresh_mask: np.ndarray = np.ones(target_shape, dtype=np.uint8)
+        thresh_mask: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
 
         for idx, tgt in enumerate(target):
             for class_idx, _tgt in enumerate(tgt.values()):
                 # Draw each polygon on gt
                 if _tgt.shape[0] == 0:
                     # Empty image, full masked
-                    # seg_mask[idx, :, :, class_idx] = False
                     seg_mask[idx, class_idx] = False
 
                 # Absolute bounding boxes
@@ -326,10 +331,9 @@ def build_target(
                 )
 
                 boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
 
-                for box, box_size, poly in zip(abs_boxes, boxes_size, polys):
+                for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
                     # Mask boxes that are too small
                     if box_size < self.min_size_box:
-                        # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
                         seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                         continue
@@ -339,19 +343,17 @@ def build_target(
                     subject = [tuple(coor) for coor in poly]
                     padding = pyclipper.PyclipperOffset()
                     padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
-                    shrinked = padding.Execute(-distance)
+                    shrunken = padding.Execute(-distance)
 
                     # Draw polygon on gt if it is valid
-                    if len(shrinked) == 0:
-                        # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
+                    if len(shrunken) == 0:
                         seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                         continue
-                    shrinked = np.array(shrinked[0]).reshape(-1, 2)
-                    if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid:
-                        # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
+                    shrunken = np.array(shrunken[0]).reshape(-1, 2)
+                    if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                         seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                         continue
-                    cv2.fillPoly(seg_target[idx, class_idx], [shrinked.astype(np.int32)], 1.0)  # type: ignore[call-overload]
+                    cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)  # type: ignore[call-overload]
 
                     # Draw on both thresh map and thresh mask
                     poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py
index 99adcd0e0..02996c780 100644
--- a/doctr/models/detection/differentiable_binarization/pytorch.py
+++ b/doctr/models/detection/differentiable_binarization/pytorch.py
@@ -57,24 +57,28 @@ def __init__(
 
         conv_layer = DeformConv2d if deform_conv else nn.Conv2d
 
-        self.in_branches = nn.ModuleList([
-            nn.Sequential(
-                conv_layer(chans, out_channels, 1, bias=False),
-                nn.BatchNorm2d(out_channels),
-                nn.ReLU(inplace=True),
-            )
-            for idx, chans in enumerate(in_channels)
-        ])
+        self.in_branches = nn.ModuleList(
+            [
+                nn.Sequential(
+                    conv_layer(chans, out_channels, 1, bias=False),
+                    nn.BatchNorm2d(out_channels),
+                    nn.ReLU(inplace=True),
+                )
+                for idx, chans in enumerate(in_channels)
+            ]
+        )
         self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
-        self.out_branches = nn.ModuleList([
-            nn.Sequential(
-                conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
-                nn.BatchNorm2d(out_chans),
-                nn.ReLU(inplace=True),
-                nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
-            )
-            for idx, chans in enumerate(in_channels)
-        ])
+        self.out_branches = nn.ModuleList(
+            [
+                nn.Sequential(
+                    conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
+                    nn.BatchNorm2d(out_chans),
+                    nn.ReLU(inplace=True),
+                    nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
+                )
+                for idx, chans in enumerate(in_channels)
+            ]
+        )
 
     def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
         if len(x) != len(self.out_branches):
@@ -213,7 +217,15 @@ def forward(
 
         return out
 
-    def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
+    def compute_loss(
+        self,
+        out_map: torch.Tensor,
+        thresh_map: torch.Tensor,
+        target: List[np.ndarray],
+        gamma: float = 2.0,
+        alpha: float = 0.5,
+        eps: float = 1e-8,
+    ) -> torch.Tensor:
         """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
         and a list of masks for each image. From there it computes the loss with the model output
 
@@ -222,6 +234,9 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target:
             out_map: output feature map of the model of shape (N, C, H, W)
             thresh_map: threshold map of shape (N, C, H, W)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
+            gamma: modulating factor in the focal loss formula
+            alpha: balancing factor in the focal loss formula
+            eps: epsilon factor in dice loss
 
         Returns:
         -------
@@ -230,48 +245,40 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target:
         prob_map = torch.sigmoid(out_map)
         thresh_map = torch.sigmoid(thresh_map)
 
-        targets = self.build_target(target, prob_map.shape, False)  # type: ignore[arg-type]
+        targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
         thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3])
         thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device)
 
-        # Compute balanced BCE loss for proba_map
-        bce_scale = 5.0
-        balanced_bce_loss = torch.zeros(1, device=out_map.device)
+        focal_loss = torch.zeros(1, device=out_map.device)
         dice_loss = torch.zeros(1, device=out_map.device)
         l1_loss = torch.zeros(1, device=out_map.device)
         if torch.any(seg_mask):
-            bce_loss = F.binary_cross_entropy_with_logits(
-                out_map,
-                seg_target,
-                reduction="none",
-            )[seg_mask]
-
-            neg_target = 1 - seg_target[seg_mask]
-            positive_count = seg_target[seg_mask].sum()
-            negative_count = torch.minimum(neg_target.sum(), 3.0 * positive_count)
-            negative_loss = bce_loss * neg_target
-            negative_loss = negative_loss.sort().values[-int(negative_count.item()) :]
-            sum_losses = torch.sum(bce_loss * seg_target[seg_mask]) + torch.sum(negative_loss)
-            balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)
-
-            # Compute dice loss for approxbin_map
-            bin_map = 1 / (1 + torch.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))
-
-            bce_min = bce_loss.min()
-            weights = (bce_loss - bce_min) / (bce_loss.max() - bce_min) + 1.0
-            inter = torch.sum(bin_map * seg_target[seg_mask] * weights)
-            union = torch.sum(bin_map) + torch.sum(seg_target[seg_mask]) + 1e-8  # type: ignore[call-overload]
-            dice_loss = 1 - 2.0 * inter / union
+            # Focal loss
+            bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
+
+            if gamma < 0:
+                raise ValueError("Value of gamma should be greater than or equal to zero.")
+            p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
+            alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
+            # Unreduced version
+            focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
+            # Class reduced
+            focal_loss = (seg_mask * focal_loss).sum() / seg_mask.sum()
+
+            # Dice loss
+            inter = (seg_mask * prob_map * seg_target).sum()
+            cardinality = (seg_mask * (prob_map + seg_target)).sum()
+            dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)
 
         # Compute l1 loss for thresh_map
         l1_scale = 10.0
         if torch.any(thresh_mask):
             l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
 
-        return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss  # type: ignore[return-value]
+        return l1_scale * l1_loss + focal_loss + dice_loss  # type: ignore[return-value]
 
 
 def _dbnet(
diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py
index 21943f947..e6bf5024b 100644
--- a/doctr/models/detection/differentiable_binarization/tensorflow.py
+++ b/doctr/models/detection/differentiable_binarization/tensorflow.py
@@ -144,20 +144,24 @@ def __init__(
         _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
         output_shape = tuple(self.fpn(_inputs).shape)
 
-        self.probability_head = keras.Sequential([
-            *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
-            layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
-            layers.BatchNormalization(),
-            layers.Activation("relu"),
-            layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
-        ])
-        self.threshold_head = keras.Sequential([
-            *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
-            layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
-            layers.BatchNormalization(),
-            layers.Activation("relu"),
-            layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
-        ])
+        self.probability_head = keras.Sequential(
+            [
+                *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
+                layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
+                layers.BatchNormalization(),
+                layers.Activation("relu"),
+                layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
+            ]
+        )
+        self.threshold_head = keras.Sequential(
+            [
+                *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
+                layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
+                layers.BatchNormalization(),
+                layers.Activation("relu"),
+                layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
+            ]
+        )
 
         self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
 
@@ -166,6 +170,9 @@ def compute_loss(
         out_map: tf.Tensor,
         thresh_map: tf.Tensor,
         target: List[Dict[str, np.ndarray]],
+        gamma: float = 2.0,
+        alpha: float = 0.5,
+        eps: float = 1e-8,
     ) -> tf.Tensor:
         """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
         and a list of masks for each image. From there it computes the loss with the model output
@@ -175,6 +182,9 @@ def compute_loss(
             out_map: output feature map of the model of shape (N, H, W, C)
             thresh_map: threshold map of shape (N, H, W, C)
             target: list of dictionary where each dict has a `boxes` and a `flags` entry
+            gamma: modulating factor in the focal loss formula
+            alpha: balancing factor in the focal loss formula
+            eps: epsilon factor in dice loss
 
         Returns:
         -------
@@ -183,36 +193,30 @@ def compute_loss(
         prob_map = tf.math.sigmoid(out_map)
         thresh_map = tf.math.sigmoid(thresh_map)
 
-        seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape, True)
+        seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True)
         seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
         seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
+        seg_mask = tf.cast(seg_mask, tf.float32)
         thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
         thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)
 
-        # Compute balanced BCE loss for proba_map
-        bce_scale = 5.0
-        bce_loss = tf.keras.losses.binary_crossentropy(
-            seg_target[..., None],
-            out_map[..., None],
-            from_logits=True,
-        )[seg_mask]
-
-        neg_target = 1 - seg_target[seg_mask]
-        positive_count = tf.math.reduce_sum(seg_target[seg_mask])
-        negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
-        negative_loss = bce_loss * neg_target
-        negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
-        sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
-        balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)
-
-        # Compute dice loss for approxbin_map
-        bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))
-
-        bce_min = tf.math.reduce_min(bce_loss)
-        weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
-        inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
-        union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
-        dice_loss = 1 - 2.0 * inter / union
+        bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
+
+        # Focal loss
+        if gamma < 0:
+            raise ValueError("Value of gamma should be greater than or equal to zero.")
+        # Convert logits to prob, compute gamma factor
+        p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
+        alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
+        # Unreduced loss
+        focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
+        # Class reduced
+        focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))
+
+        # Dice loss
+        inter = tf.math.reduce_sum(seg_mask * prob_map * seg_target, (0, 1, 2, 3))
+        cardinality = tf.math.reduce_sum((prob_map + seg_target), (0, 1, 2, 3))
+        dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)
 
         # Compute l1 loss for thresh_map
         l1_scale = 10.0
@@ -221,7 +225,7 @@ def compute_loss(
         else:
             l1_loss = tf.constant(0.0)
 
-        return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
+        return l1_scale * l1_loss + focal_loss + dice_loss
 
     def call(
         self,
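
Note for reviewers: below is a minimal standalone PyTorch sketch of the same focal + dice + L1 combination introduced by this patch, useful for sanity-checking the formulas outside the model classes. The function name `focal_dice_l1_loss`, the dummy tensor shapes, and passing precomputed masks directly (instead of deriving them via `build_target`) are illustrative assumptions, not part of the patch; the hard-coded `10.0` mirrors `l1_scale` above.

```python
import torch
import torch.nn.functional as F


def focal_dice_l1_loss(
    out_map: torch.Tensor,  # raw segmentation logits (N, C, H, W)
    thresh_map: torch.Tensor,  # raw threshold logits (N, C, H, W)
    seg_target: torch.Tensor,  # binary segmentation target (N, C, H, W)
    seg_mask: torch.Tensor,  # bool mask of valid pixels (N, C, H, W)
    thresh_target: torch.Tensor,  # threshold map target (N, C, H, W)
    thresh_mask: torch.Tensor,  # bool mask for the threshold map
    gamma: float = 2.0,
    alpha: float = 0.5,
    eps: float = 1e-8,
) -> torch.Tensor:
    prob_map = torch.sigmoid(out_map)
    thresh_prob = torch.sigmoid(thresh_map)

    # Focal loss: per-pixel BCE rescaled by alpha_t * (1 - p_t) ** gamma,
    # then averaged over the valid (unmasked) pixels only
    bce = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
    p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
    alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
    focal = (seg_mask * alpha_t * (1 - p_t) ** gamma * bce).sum() / seg_mask.sum()

    # Dice loss on the probability map, restricted to the valid pixels
    inter = (seg_mask * prob_map * seg_target).sum()
    cardinality = (seg_mask * (prob_map + seg_target)).sum()
    dice = 1 - 2 * (inter + eps) / (cardinality + eps)

    # L1 loss on the threshold map, restricted to its own mask
    l1 = torch.abs(thresh_prob[thresh_mask] - thresh_target[thresh_mask]).mean()

    return 10.0 * l1 + focal + dice


# Quick sanity check on random tensors
torch.manual_seed(0)
n, c, h, w = 2, 1, 32, 32
loss = focal_dice_l1_loss(
    out_map=torch.randn(n, c, h, w),
    thresh_map=torch.randn(n, c, h, w),
    seg_target=(torch.rand(n, c, h, w) > 0.9).float(),
    seg_mask=torch.ones(n, c, h, w, dtype=torch.bool),
    thresh_target=torch.rand(n, c, h, w),
    thresh_mask=torch.ones(n, c, h, w, dtype=torch.bool),
)
print(loss.item())
```

Unlike the previous balanced BCE, neither term here needs hard-negative mining or per-pixel weighting, which is what allows `compute_loss` in both backends to drop the sorting/top-k logic.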