Skip to content
This repository has been archived by the owner on Dec 5, 2021. It is now read-only.

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Arjun Variar authored and Arjun Variar committed Dec 11, 2019
1 parent e6ec002 commit 4925e83
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
# Label Smoothed Aggregation CrossEntropyLoss
Label smoothed Aggregation cross entropy loss for generalisation in sequence to sequence tasks.

This is useful for generalization in sequence to sequence tasks, helps lower the ECE loss.

For more information please refer too,
(When Does Label Smoothing Help?)[https://arxiv.org/abs/1906.02629]
(Aggregation Cross-Entropy for Sequence Recognition)[https://arxiv.org/abs/1904.08364]
24 changes: 24 additions & 0 deletions lsaceloss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class ACELabelSmoothingLoss(nn.Module):

def __init__(self, alpha=0.1):
super().__init__()
self.alpha = alpha

def forward(self, logits, targets, input_lengths, target_lengths):
T_, bs, class_size = logits.size()
tagets_split = list(torch.split(targets, target_lengths.tolist()))
targets_padded = torch.nn.utils.rnn.pad_sequence(tagets_split, batch_first=True, padding_value=0)
targets_padded = F.one_hot(targets_padded.long(), num_classes=class_size) # batch, seq, class
targets_padded = (targets_padded * (1-self.alpha)) + (self.alpha/class_size)
targets_padded = torch.sum(targets_padded, 1).float().cuda() # sum across seq, to get batch * class
targets_padded[:,0] = T_ - target_lengths
probs = torch.softmax(logits, dim=2) # softmax on class
probs = torch.sum(probs, 0) # sum across seq, to get batch * class
probs = probs/T_
targets_padded = targets_padded/T_
targets_padded = F.normalize(targets_padded, p=1, dim=1)
return F.kl_div(torch.log(probs), targets_padded, reduction='batchmean')

0 comments on commit 4925e83

Please sign in to comment.