Label Smoothing in PyTorch

Jared Nielsen · Apr 15, 2019 · Viewed 16k times

I'm building a ResNet-18 classification model for the Stanford Cars dataset using transfer learning. I would like to implement label smoothing to penalize overconfident predictions and improve generalization.

TensorFlow has a simple label_smoothing keyword argument in its cross-entropy losses (e.g. tf.keras.losses.CategoricalCrossentropy). Has anyone built a similar plug-and-play function for PyTorch?
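
For reference, this is the one-liner I mean on the TensorFlow side, where the smoothing factor is passed directly to the loss:

import tensorflow as tf

# Keras cross entropy with built-in label smoothing
loss_fn = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, label_smoothing=0.1)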

Answer

Shital Shah · Dec 10, 2019

I've been looking for an option that derives from _Loss like the other loss classes in PyTorch and respects basic parameters such as reduction. Unfortunately I couldn't find a straightforward replacement, so I ended up writing my own. I haven't fully tested it yet. It puts 1 - smoothing on the true class and spreads the remaining smoothing mass uniformly over the other n_classes - 1 classes:

import torch
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F

class SmoothCrossEntropyLoss(_WeightedLoss):
    def __init__(self, weight=None, reduction='mean', smoothing=0.0):
        # _WeightedLoss already stores weight and reduction for us
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing

    @staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0):
        assert 0 <= smoothing < 1
        with torch.no_grad():
            # Fill every class with smoothing / (n_classes - 1), then put
            # 1 - smoothing on the true class, so each row sums to 1.
            targets = torch.empty(size=(targets.size(0), n_classes),
                                  device=targets.device) \
                .fill_(smoothing / (n_classes - 1)) \
                .scatter_(1, targets.unsqueeze(1), 1. - smoothing)
        return targets

    def forward(self, inputs, targets):
        targets = SmoothCrossEntropyLoss._smooth_one_hot(
            targets, inputs.size(-1), self.smoothing)
        lsm = F.log_softmax(inputs, -1)

        # Optional per-class weights, as in nn.CrossEntropyLoss
        if self.weight is not None:
            lsm = lsm * self.weight.unsqueeze(0)

        # Cross entropy against the smoothed target distribution
        loss = -(targets * lsm).sum(-1)

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss
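
A quick usage sketch (the shapes below are just placeholders). With smoothing=0.0 the loss reduces to standard cross entropy, which gives an easy sanity check:

criterion = SmoothCrossEntropyLoss(smoothing=0.1)

logits = torch.randn(4, 10)          # batch of 4 samples, 10 classes
labels = torch.tensor([1, 0, 3, 7])  # integer class indices

loss = criterion(logits, labels)

# Sanity check: zero smoothing should match F.cross_entropy exactly
baseline = SmoothCrossEntropyLoss(smoothing=0.0)(logits, labels)
assert torch.allclose(baseline, F.cross_entropy(logits, labels))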

Other options: