Skip to content

Commit

Permalink
[FIX] Dice loss computation in both backends (mindee#1442)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Jan 26, 2024
1 parent abf0571 commit f5445ef
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 20 deletions.
19 changes: 12 additions & 7 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,18 @@ def compute_loss(
# Unreduced version
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
# Class reduced
-focal_loss = (seg_mask * focal_loss).sum() / seg_mask.sum()
-
-# Compute dice loss for approx binary_map
-binary_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
-inter = (seg_mask * binary_map * seg_target).sum() # type: ignore[attr-defined]
-cardinality = (seg_mask * (binary_map + seg_target)).sum() # type: ignore[attr-defined]
-dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)
+focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))
+
+# Compute dice loss for each class or for approx binary_map
+if len(self.class_names) > 1:
+dice_map = torch.softmax(out_map, dim=1)
+else:
+# compute binary map instead
+dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+# Class reduced
+inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
+cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
+dice_loss = (1 - 2 * inter / (cardinality + eps)).mean()

# Compute l1 loss for thresh_map
if torch.any(thresh_mask):
Expand Down
15 changes: 10 additions & 5 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,16 @@ def compute_loss(
# Class reduced
focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))

-# Compute dice loss for approx binary_map
-binary_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
-inter = tf.reduce_sum(seg_mask * binary_map * seg_target, (0, 1, 2, 3))
-cardinality = tf.reduce_sum((binary_map + seg_target), (0, 1, 2, 3))
-dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)
+# Compute dice loss for each class or for approx binary_map
+if len(self.class_names) > 1:
+dice_map = tf.nn.softmax(out_map, axis=-1)
+else:
+# compute binary map instead
+dice_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
+# Class-reduced dice loss
+inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
+cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
+dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))

# Compute l1 loss for thresh_map
if tf.reduce_any(thresh_mask):
Expand Down
10 changes: 6 additions & 4 deletions doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,12 @@ def compute_loss(
# Class reduced
focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))

-# Dice loss
-inter = (seg_mask * proba_map * seg_target).sum((0, 1, 2, 3))
-cardinality = (seg_mask * (proba_map + seg_target)).sum((0, 1, 2, 3))
-dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)
+# Compute dice loss for each class
+dice_map = torch.softmax(out_map, dim=1) if len(self.class_names) > 1 else proba_map
+# Class reduced
+inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
+cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
+dice_loss = (1 - 2 * inter / (cardinality + eps)).mean()

# Return the full loss (equal sum of focal loss and dice loss)
return focal_loss + dice_loss
Expand Down
10 changes: 6 additions & 4 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,12 @@ def compute_loss(
# Class reduced
focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))

-# Dice loss
-inter = tf.math.reduce_sum(seg_mask * proba_map * seg_target, (0, 1, 2, 3))
-cardinality = tf.math.reduce_sum((proba_map + seg_target), (0, 1, 2, 3))
-dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)
+# Compute dice loss for each class
+dice_map = tf.nn.softmax(out_map, axis=-1) if len(self.class_names) > 1 else proba_map
+# Class-reduced dice loss
+inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
+cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
+dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))

return focal_loss + dice_loss

Expand Down

0 comments on commit f5445ef

Please sign in to comment.