Mixup implementation in progress

* initial impl w/ label smoothing converging, but needs more testing
Ross Wightman 2019-05-13 19:05:40 -07:00
parent c3fbdd4655
commit fee607edf6
5 changed files with 45 additions and 5 deletions


@ -29,7 +29,7 @@ I've included a few of my favourite models, but this is not an exhaustive collection.
* PNasNet (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch))
* DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
* DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
-* My generic MobileNet (GenMobileNet) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable, InvertedResidual, etc blocks
+* Generic MobileNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
* MNASNet B1, A1 (Squeeze-Excite), and Small
* MobileNet-V1
* MobileNet-V2
@ -49,7 +49,8 @@ Several (less common) features that I often utilize in my projects are included.
* PyTorch w/ single GPU single process (AMP optional)
* A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
* A 'Test Time Pool' wrapper that can wrap any of the included models and usually provide improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs)
-* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Smoothed Softmax, etc)
+* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc)
+* Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing
* An inference script that dumps output to CSV is provided as an example

### Custom Weights
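For reference, the Mixup bullet added above refers to the input/label blending scheme from the linked paper (https://arxiv.org/abs/1710.09412), which can be summarized as:

```latex
\tilde{x} = \lambda x_i + (1 - \lambda)\, x_j, \qquad
\tilde{y} = \lambda y_i + (1 - \lambda)\, y_j, \qquad
\lambda \sim \mathrm{Beta}(\alpha, \alpha)
```

where x_i, x_j are two training inputs, y_i, y_j their one-hot labels, and alpha corresponds to the `--mixup` value introduced further down in this commit.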


@ -1 +1 @@
-from loss.cross_entropy import LabelSmoothingCrossEntropy
+from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy


@ -1,3 +1,4 @@
+import torch
import torch.nn as nn
import torch.nn.functional as F
@ -24,3 +25,12 @@ class LabelSmoothingCrossEntropy(nn.Module):
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
+
+
+class SparseLabelCrossEntropy(nn.Module):
+
+    def __init__(self):
+        super(SparseLabelCrossEntropy, self).__init__()
+
+    def forward(self, x, target):
+        loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
+        return loss.mean()
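The new SparseLabelCrossEntropy consumes a per-sample target distribution with the same shape as the logits (which is what the mixup_target helper further down produces) rather than integer class indices. A minimal, self-contained sketch, re-declaring the class locally with random illustrative data, showing that it reduces to the standard loss when the targets are plain one-hot vectors:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local copy of the class added above, so this check runs on its own.
class SparseLabelCrossEntropy(nn.Module):
    def forward(self, x, target):
        # target is a (batch, num_classes) probability distribution, not an index tensor
        loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
        return loss.mean()

logits = torch.randn(4, 10)              # batch of 4, 10 classes
labels = torch.randint(0, 10, (4,))      # integer class indices
soft = F.one_hot(labels, 10).float()     # the same labels as one-hot distributions

# With strict one-hot targets the soft-target loss matches nn.CrossEntropyLoss;
# the two only diverge once the targets are smoothed and/or mixed.
print(SparseLabelCrossEntropy()(logits, soft))
print(nn.CrossEntropyLoss()(logits, labels))
```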


@ -13,7 +13,7 @@ except ImportError:
from data import *
from models import create_model, resume_checkpoint
from utils import *
-from loss import LabelSmoothingCrossEntropy
+from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
from optim import create_optimizer
from scheduler import create_scheduler
@ -79,6 +79,10 @@ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
+parser.add_argument('--mixup', type=float, default=0.0,
+                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
+parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
+                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='label smoothing (default: 0.1)')
parser.add_argument('--bn-tf', action='store_true', default=False,
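The two new options wire mixup into the existing argparse setup. A tiny standalone check (the parse_args list is illustrative only) showing their defaults and that `--mixup-off-epoch` is read back as `args.mixup_off_epoch` in the training loop further down:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--mixup', type=float, default=0.0,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')

args = parser.parse_args(['--mixup', '0.2', '--mixup-off-epoch', '150'])
print(args.mixup, args.mixup_off_epoch)  # 0.2 150  (dashes become underscores in the dest)
```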
@ -246,7 +250,11 @@ def main():
        distributed=args.distributed,
    )

-    if args.smoothing:
+    if args.mixup > 0.:
+        # smoothing is handled with mixup label transform
+        train_loss_fn = SparseLabelCrossEntropy().cuda()
+        validate_loss_fn = nn.CrossEntropyLoss().cuda()
+    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
@ -314,6 +322,13 @@ def train_epoch(
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)

+        if args.mixup > 0.:
+            lam = 1.
+            if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
+                lam = np.random.beta(args.mixup, args.mixup)
+            input.mul_(lam).add_(1 - lam, input.flip(0))
+            target = mixup_target(target, args.num_classes, lam, args.smoothing)
+
        output = model(input)
        loss = loss_fn(output, target)
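The added block mixes each image with the one at the mirrored position of the same batch (via `flip(0)`), using a single `lam` drawn per batch, and swaps the integer targets for blended soft targets. Note that `mul_` runs in place before `input.flip(0)` is evaluated, so the flipped operand in that one-liner is already scaled by `lam`. A self-contained sketch of the blending step using out-of-place ops, with the soft-target construction from the helpers added further down inlined; the function name `mixup_batch` and the concrete shapes are illustrative, not part of the commit:

```python
import numpy as np
import torch

def mixup_batch(input, target, num_classes, alpha=0.2, smoothing=0.1):
    # One mixing coefficient for the whole batch, as in the training loop above.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0. else 1.
    # Blend each image with its counterpart in the reversed batch (flip along the batch dim).
    mixed_input = input.mul(lam).add(input.flip(0), alpha=1. - lam)
    # Matching soft targets: smoothed one-hot rows blended with the same lam.
    off_value = smoothing / num_classes
    on_value = 1. - smoothing + off_value
    y1 = torch.full((target.size(0), num_classes), off_value).scatter_(1, target.view(-1, 1), on_value)
    mixed_target = lam * y1 + (1. - lam) * y1.flip(0)
    return mixed_input, mixed_target

x = torch.randn(8, 3, 224, 224)
y = torch.randint(0, 10, (8,))
mx, my = mixup_batch(x, y, num_classes=10)
print(mx.shape, my.shape)  # torch.Size([8, 3, 224, 224]) torch.Size([8, 10])
```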


@ -5,6 +5,7 @@ import shutil
import glob
import csv
import operator
+import numpy as np
from collections import OrderedDict
@ -139,6 +140,19 @@ def accuracy(output, target, topk=(1,)):
    return res


+def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+    x = x.long().view(-1, 1)
+    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+
+
+def mixup_target(target, num_classes, lam=1., smoothing=0.0):
+    off_value = smoothing / num_classes
+    on_value = 1. - smoothing + off_value
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
+    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
+    return lam*y1 + (1. - lam)*y2
+
+
def get_outdir(path, *paths, inc=False):
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
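A quick CPU check of what the new mixup_target helper produces. The helpers are re-declared locally with device='cpu' (the version above defaults one_hot to device='cuda') purely so the sketch runs anywhere; the concrete numbers assume num_classes=5, lam=0.7, smoothing=0.1:

```python
import torch

# CPU copies of the helpers added above; only the device default differs.
def one_hot(x, num_classes, on_value=1., off_value=0., device='cpu'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)

def mixup_target(target, num_classes, lam=1., smoothing=0.0):
    off_value = smoothing / num_classes
    on_value = 1. - smoothing + off_value
    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
    return lam*y1 + (1. - lam)*y2

target = torch.tensor([0, 2])  # a batch of two integer labels
print(mixup_target(target, num_classes=5, lam=0.7, smoothing=0.1))
# Each row blends the smoothed one-hot vector for its own label (weight 0.7) with the
# smoothed vector for the label at the mirrored batch position (weight 0.3):
# tensor([[0.6500, 0.0200, 0.2900, 0.0200, 0.0200],
#         [0.2900, 0.0200, 0.6500, 0.0200, 0.0200]])
```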