Commit 4925e83
Arjun Variar authored and committed on Dec 11, 2019
1 parent e6ec002
Showing 2 changed files with 30 additions and 0 deletions.
@@ -1,2 +1,8 @@
# Label Smoothed Aggregation CrossEntropyLoss

Label-smoothed aggregation cross-entropy loss for generalisation in sequence-to-sequence tasks.

It improves generalisation in sequence-to-sequence tasks and helps lower the expected calibration error (ECE).

For more information, please refer to:

[When Does Label Smoothing Help?](https://arxiv.org/abs/1906.02629)

[Aggregation Cross-Entropy for Sequence Recognition](https://arxiv.org/abs/1904.08364)
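The two ideas combine as follows: label smoothing replaces each one-hot target with a mixture of the one-hot vector and a uniform distribution, and aggregation cross-entropy compares per-class count distributions aggregated over the sequence rather than per-timestep predictions. The snippet below is an illustrative sketch only, not part of this commit; the number of classes, the label sequence, and the smoothing factor `alpha` are made-up example values.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch (not from this commit): smooth a one-hot target,
# then aggregate the targets over the sequence into a per-class count vector.
alpha, num_classes = 0.1, 5                              # assumed example values
targets = torch.tensor([2, 3, 3])                        # one short label sequence

one_hot = F.one_hot(targets, num_classes).float()        # (seq, class)
smoothed = one_hot * (1 - alpha) + alpha / num_classes   # label smoothing
aggregated = smoothed.sum(dim=0)                         # ACE: expected count per class

print(smoothed)
print(aggregated)  # approximately [0.06, 0.06, 0.96, 1.86, 0.06]
```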
@@ -0,0 +1,24 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ACELabelSmoothingLoss(nn.Module):
    """Aggregation cross-entropy loss with label smoothing.

    Compares the aggregated, per-class prediction distribution against the
    smoothed, aggregated target distribution via KL divergence.
    """

    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha  # label-smoothing factor

    def forward(self, logits, targets, input_lengths, target_lengths):
        # logits: (T, batch, class); targets: concatenated label sequences;
        # target_lengths: length of each target sequence in the batch.
        # input_lengths is unused; kept for interface parity with CTC-style losses.
        T_, bs, class_size = logits.size()

        # Split the flat target tensor into per-sample sequences and pad them.
        targets_split = list(torch.split(targets, target_lengths.tolist()))
        targets_padded = torch.nn.utils.rnn.pad_sequence(targets_split, batch_first=True, padding_value=0)

        # One-hot encode and apply label smoothing: (batch, seq, class).
        targets_padded = F.one_hot(targets_padded.long(), num_classes=class_size)
        targets_padded = targets_padded * (1 - self.alpha) + self.alpha / class_size

        # Aggregate over the sequence dimension to get per-class counts: (batch, class).
        targets_padded = torch.sum(targets_padded, 1).float().to(logits.device)
        # Assign the remaining (blank) frames to class 0.
        targets_padded[:, 0] = (T_ - target_lengths).to(targets_padded)

        # Aggregate the predicted distribution over time: (batch, class).
        probs = torch.softmax(logits, dim=2)  # softmax over the class dimension
        probs = torch.sum(probs, 0) / T_

        # Normalise the target counts into a probability distribution.
        targets_padded = targets_padded / T_
        targets_padded = F.normalize(targets_padded, p=1, dim=1)

        # KL divergence between log-predictions and the target distribution.
        return F.kl_div(torch.log(probs), targets_padded, reduction='batchmean')
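A minimal usage sketch, assuming logits of shape (T, batch, classes), a flat tensor of concatenated target sequences, and per-sample length tensors in the CTC style. The sizes and label values below are invented for illustration and are not part of this commit.

```python
import torch

# Invented example shapes and labels, for illustration only.
T, batch, num_classes = 50, 2, 10
logits = torch.randn(T, batch, num_classes, requires_grad=True)  # (T, batch, class)
targets = torch.tensor([3, 5, 5, 2, 7, 1, 4])                    # concatenated label sequences
target_lengths = torch.tensor([4, 3])                            # lengths of the two sequences
input_lengths = torch.full((batch,), T)                          # unused by the loss, kept for API parity

criterion = ACELabelSmoothingLoss(alpha=0.1)
loss = criterion(logits, targets, input_lengths, target_lengths)
loss.backward()
print(loss.item())
```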