Mixup implemention in progress
* initial impl w/ label smoothing converging, but needs more testingpull/1/head
parent
c3fbdd4655
commit
fee607edf6
|
@ -29,7 +29,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec
|
|||
* PNasNet (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch))
|
||||
* DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
|
||||
* DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
|
||||
* My generic MobileNet (GenMobileNet) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable, InvertedResidual, etc blocks
|
||||
* Generic MobileNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
|
||||
* MNASNet B1, A1 (Squeeze-Excite), and Small
|
||||
* MobileNet-V1
|
||||
* MobileNet-V2
|
||||
|
@ -49,7 +49,8 @@ Several (less common) features that I often utilize in my projects are included.
|
|||
* PyTorch w/ single GPU single process (AMP optional)
|
||||
* A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
|
||||
* A 'Test Time Pool' wrapper that can wrap any of the included models and usually provide improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs)
|
||||
* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Smoothed Softmax, etc)
|
||||
* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc)
|
||||
* Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing
|
||||
* An inference script that dumps output to CSV is provided as an example
|
||||
|
||||
### Custom Weights
|
||||
|
|
|
@ -1 +1 @@
|
|||
from loss.cross_entropy import LabelSmoothingCrossEntropy
|
||||
from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
|
|
@ -1,3 +1,4 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
@ -24,3 +25,12 @@ class LabelSmoothingCrossEntropy(nn.Module):
|
|||
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class SparseLabelCrossEntropy(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(SparseLabelCrossEntropy, self).__init__()
|
||||
|
||||
def forward(self, x, target):
|
||||
loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
|
||||
return loss.mean()
|
||||
|
|
19
train.py
19
train.py
|
@ -13,7 +13,7 @@ except ImportError:
|
|||
from data import *
|
||||
from models import create_model, resume_checkpoint
|
||||
from utils import *
|
||||
from loss import LabelSmoothingCrossEntropy
|
||||
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
|
||||
from optim import create_optimizer
|
||||
from scheduler import create_scheduler
|
||||
|
||||
|
@ -79,6 +79,10 @@ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
|
|||
help='SGD momentum (default: 0.9)')
|
||||
parser.add_argument('--weight-decay', type=float, default=0.0001,
|
||||
help='weight decay (default: 0.0001)')
|
||||
parser.add_argument('--mixup', type=float, default=0.0,
|
||||
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
|
||||
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
|
||||
help='turn off mixup after this epoch, disabled if 0 (default: 0)')
|
||||
parser.add_argument('--smoothing', type=float, default=0.1,
|
||||
help='label smoothing (default: 0.1)')
|
||||
parser.add_argument('--bn-tf', action='store_true', default=False,
|
||||
|
@ -246,7 +250,11 @@ def main():
|
|||
distributed=args.distributed,
|
||||
)
|
||||
|
||||
if args.smoothing:
|
||||
if args.mixup > 0.:
|
||||
# smoothing is handled with mixup label transform
|
||||
train_loss_fn = SparseLabelCrossEntropy().cuda()
|
||||
validate_loss_fn = nn.CrossEntropyLoss().cuda()
|
||||
elif args.smoothing:
|
||||
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
|
||||
validate_loss_fn = nn.CrossEntropyLoss().cuda()
|
||||
else:
|
||||
|
@ -314,6 +322,13 @@ def train_epoch(
|
|||
last_batch = batch_idx == last_idx
|
||||
data_time_m.update(time.time() - end)
|
||||
|
||||
if args.mixup > 0.:
|
||||
lam = 1.
|
||||
if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
|
||||
lam = np.random.beta(args.mixup, args.mixup)
|
||||
input.mul_(lam).add_(1 - lam, input.flip(0))
|
||||
target = mixup_target(target, args.num_classes, lam, args.smoothing)
|
||||
|
||||
output = model(input)
|
||||
|
||||
loss = loss_fn(output, target)
|
||||
|
|
14
utils.py
14
utils.py
|
@ -5,6 +5,7 @@ import shutil
|
|||
import glob
|
||||
import csv
|
||||
import operator
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
|
@ -139,6 +140,19 @@ def accuracy(output, target, topk=(1,)):
|
|||
return res
|
||||
|
||||
|
||||
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
|
||||
x = x.long().view(-1, 1)
|
||||
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
|
||||
|
||||
|
||||
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
|
||||
off_value = smoothing / num_classes
|
||||
on_value = 1. - smoothing + off_value
|
||||
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
|
||||
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
|
||||
return lam*y1 + (1. - lam)*y2
|
||||
|
||||
|
||||
def get_outdir(path, *paths, inc=False):
|
||||
outdir = os.path.join(path, *paths)
|
||||
if not os.path.exists(outdir):
|
||||
|
|
Loading…
Reference in New Issue