From 4814b8d602278de0e48351aba3221a6dff0f3630 Mon Sep 17 00:00:00 2001 From: liaoxingyu Date: Mon, 26 Aug 2019 11:31:37 +0800 Subject: [PATCH] 1. add label smooth 2. update experiment --- README.md | 5 +++-- configs/softmax_triplet.yml | 2 +- datasets | 1 - modeling/backbones/resnet.py | 5 +++-- modeling/losses/label_smooth.py | 37 +++++++++++++++++++++++++++++++++ modeling/losses/loss.py | 22 +++++++------------- scripts/train_duke.sh | 22 +++++++++++++++++--- scripts/train_market.sh | 30 +++++++++++++++++++------- tools/train.py | 4 +--- 9 files changed, 93 insertions(+), 35 deletions(-) delete mode 120000 datasets create mode 100644 modeling/losses/label_smooth.py diff --git a/README.md b/README.md index 0338ba4..a6bcb54 100644 --- a/README.md +++ b/README.md @@ -78,11 +78,12 @@ python3 tools/test.py DATASET.TEST_NAMES 'duke' \ | size=(256, 128) batch_size=64 (16 id x 4 imgs) | | | | | | :------: | :-----: | :-----: | :--: | :---: | | softmax? | ✔︎ | ✔︎ | ✔︎ | ✔︎ | +| label smooth? | | | ✔︎ | ✔︎ | | triplet? | | ✔︎ | ✔︎ | ✔︎ | | ibn? | | | ✔︎ | ✔︎ | | gcnet? | | | | ✔︎ | -| Market1501 | 93.4 (82.9) | 94.2 (86.1) | 94.9 (86.4)| 94.9 (87.6) | -| DukeMTMC-reid | 84.7 (72.7) | 87.3 (76.0) | 87.9 (77.1)| 89.0 (78.8) | +| Market1501 | 93.4 (82.9) | 94.2 (86.1) | 95.4 (87.9) | 95.2 (88.7) | +| DukeMTMC-reid | 84.7 (72.7) | 87.3 (76.0) | 89.5 (79.7) | 90.0 (80.2) | | CUHK03 | | | | | diff --git a/configs/softmax_triplet.yml b/configs/softmax_triplet.yml index d21512d..71ecdd8 100644 --- a/configs/softmax_triplet.yml +++ b/configs/softmax_triplet.yml @@ -19,7 +19,7 @@ DATALOADER: SOLVER: OPT: 'adam' - LOSSTYPE: ('softmax', 'triplet') + LOSSTYPE: ('softmax_smooth', 'triplet') MAX_EPOCHS: 150 BASE_LR: 0.00035 WEIGHT_DECAY: 0.0005 diff --git a/datasets b/datasets deleted file mode 120000 index 30f9f15..0000000 --- a/datasets +++ /dev/null @@ -1 +0,0 @@ -../datasets/ \ No newline at end of file diff --git a/modeling/backbones/resnet.py b/modeling/backbones/resnet.py index fa1a4d6..3ebf9d3 100644 --- a/modeling/backbones/resnet.py +++ b/modeling/backbones/resnet.py @@ -35,7 +35,7 @@ __all__ = ['ResNet'] class IBN(nn.Module): def __init__(self, planes): super(IBN, self).__init__() - half1 = int(planes/8) + half1 = int(planes/2) self.half = half1 half2 = planes - half1 self.IN = nn.InstanceNorm2d(half1, affine=True) @@ -44,7 +44,8 @@ class IBN(nn.Module): def forward(self, x): split = torch.split(x, self.half, 1) out1 = self.IN(split[0].contiguous()) - out2 = self.BN(torch.cat(split[1:], dim=1).contiguous()) + # out2 = self.BN(torch.cat(split[1:], dim=1).contiguous()) + out2 = self.BN(split[1].contiguous()) out = torch.cat((out1, out2), 1) return out diff --git a/modeling/losses/label_smooth.py b/modeling/losses/label_smooth.py new file mode 100644 index 0000000..ce06711 --- /dev/null +++ b/modeling/losses/label_smooth.py @@ -0,0 +1,37 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" + +import torch +from torch import nn + + +class CrossEntropyLabelSmooth(nn.Module): + """Cross entropy loss with label smoothing regularizer. + Reference: + Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. + Equation: y = (1 - epsilon) * y + epsilon / K. + Args: + num_classes (int): number of classes. + epsilon (float): weight. + """ + def __init__(self, num_classes, epsilon=0.1): + super(CrossEntropyLabelSmooth, self).__init__() + self.num_classes = num_classes + self.epsilon = epsilon + self.logsoftmax = nn.LogSoftmax(dim=1) + + def forward(self, inputs, targets): + """ + Args: + inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) + targets: ground truth labels with shape (num_classes) + """ + log_probs = self.logsoftmax(inputs) + targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) + targets = targets.to(inputs.device) + targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes + loss = (-targets * log_probs).mean(0).sum() + return loss diff --git a/modeling/losses/loss.py b/modeling/losses/loss.py index a17596a..eff763e 100644 --- a/modeling/losses/loss.py +++ b/modeling/losses/loss.py @@ -6,31 +6,23 @@ from torch import nn from .triplet_loss import TripletLoss - +from .label_smooth import CrossEntropyLabelSmooth __all__ = ['reidLoss'] class reidLoss(nn.Module): - def __init__(self, lossType:list, margin:float): + def __init__(self, lossType:list, margin:float, num_classes:float): super().__init__() self.lossType = lossType - self.ce_loss = nn.CrossEntropyLoss(reduction='none') - self.triplet_loss = TripletLoss(margin) + if 'softmax' in self.lossType: self.ce_loss = nn.CrossEntropyLoss() + if 'softmax_smooth' in self.lossType: self.ce_loss = CrossEntropyLabelSmooth(num_classes) + if 'triplet' in self.lossType: self.triplet_loss = TripletLoss(margin) def forward(self, out, target): scores, feats = out loss = 0 - if 'softmax' in self.lossType: - if len(target.size()) == 2: - loss1, loss2 = self.ce_loss(scores, target[:,0].long()), self.ce_loss(scores, target[:,1].long()) - d = loss1 * target[:,2] + loss2 * (1-target[:,2]) - else: - d = self.ce_loss(scores, target) - loss += d.mean() - if 'triplet' in self.lossType: - if len(target.size()) == 2: loss += self.triplet_loss(feats, target[:,0].long())[0] - else: loss += self.triplet_loss(feats, target)[0] - + if 'softmax' or 'softmax_smooth' in self.lossType: loss += self.ce_loss(scores, target) + if 'triplet' in self.lossType: loss += self.triplet_loss(feats, target)[0] return loss diff --git a/scripts/train_duke.sh b/scripts/train_duke.sh index 656b32a..f96b98b 100644 --- a/scripts/train_duke.sh +++ b/scripts/train_duke.sh @@ -1,8 +1,24 @@ +#!/usr/bin/env bash gpu=0 -CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ +CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ DATASETS.NAMES '("duke",)' \ DATASETS.TEST_NAMES 'duke' \ MODEL.BACKBONE 'resnet50' \ -MODEL.IBN 'True' \ -OUTPUT_DIR 'logs/test' +MODEL.WITH_IBN 'True' \ +MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +MODEL.STAGE_WITH_GCB '(False, False, False, False)' \ +SOLVER.LOSSTYPE '("softmax_smooth", "triplet")' \ +OUTPUT_DIR 'logs/2019.8.25/duke/ibn_smooth' + +#CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ +#DATASETS.NAMES '("duke",)' \ +#DATASETS.TEST_NAMES 'duke' \ +#MODEL.BACKBONE 'resnet50' \ +#MODEL.WITH_IBN 'True' \ +#MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +#MODEL.STAGE_WITH_GCB '(False, True, True, True)' \ +#SOLVER.LOSSTYPE '("softmax_smooth", "triplet")' \ +#OUTPUT_DIR 'logs/2019.8.25/duke/ibn_11_gcnet_smooth' + + diff --git a/scripts/train_market.sh b/scripts/train_market.sh index 9f851e4..466406f 100644 --- a/scripts/train_market.sh +++ b/scripts/train_market.sh @@ -1,3 +1,4 @@ +#!/usr/bin/env bash gpu=0 # CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ @@ -44,13 +45,15 @@ gpu=0 # MODEL.IBN 'False' \ # OUTPUT_DIR 'logs/2019.8.20/duke/resnet_softmax_triplet' -# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ -# DATASETS.NAMES '("duke",)' \ -# DATASETS.TEST_NAMES 'duke' \ -# MODEL.BACKBONE 'resnet50' \ -# MODEL.IBN 'True' \ -# MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ -# OUTPUT_DIR 'logs/2019.8.20/duke/resnet_ibn_softmax' + CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ + DATASETS.NAMES '("market1501",)' \ + DATASETS.TEST_NAMES 'market1501' \ + MODEL.BACKBONE 'resnet50' \ + MODEL.WITH_IBN 'True' \ + MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ + MODEL.STAGE_WITH_GCB '(False, False, False, False)' \ + SOLVER.LOSSTYPE '("softmax_smooth", "triplet")' \ + OUTPUT_DIR 'logs/2019.8.25/market/ibn_smooth' CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ DATASETS.NAMES '("market1501",)' \ @@ -59,4 +62,15 @@ MODEL.BACKBONE 'resnet50' \ MODEL.WITH_IBN 'True' \ MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ MODEL.STAGE_WITH_GCB '(False, True, True, True)' \ -OUTPUT_DIR 'logs/2019.8.22/duke/resnet_ibn_gc_softmax_triplet' \ No newline at end of file +SOLVER.LOSSTYPE '("softmax_smooth", "triplet")' \ +OUTPUT_DIR 'logs/2019.8.25/market/ibn_gc_smooth' + +CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ +DATASETS.NAMES '("duke",)' \ +DATASETS.TEST_NAMES 'duke' \ +MODEL.BACKBONE 'resnet50' \ +MODEL.WITH_IBN 'True' \ +MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +MODEL.STAGE_WITH_GCB '(False, False, False, False)' \ +SOLVER.LOSSTYPE '("softmax_smooth", "triplet")' \ +OUTPUT_DIR 'logs/2019.8.25/duke/ibn_smooth' diff --git a/tools/train.py b/tools/train.py index dff67df..2b5fc46 100644 --- a/tools/train.py +++ b/tools/train.py @@ -5,7 +5,6 @@ """ import argparse -import os from bisect import bisect_right from torch.backends import cudnn @@ -16,7 +15,6 @@ from data import get_data_bunch from engine.trainer import do_train from fastai.vision import * from modeling import * -from solver import * from utils.logger import setup_logger @@ -43,7 +41,7 @@ def train(cfg): lr_sched = Scheduler(cfg.SOLVER.BASE_LR, cfg.SOLVER.MAX_EPOCHS, lr_multistep) - loss_func = reidLoss(cfg.SOLVER.LOSSTYPE, cfg.SOLVER.MARGIN) + loss_func = reidLoss(cfg.SOLVER.LOSSTYPE, cfg.SOLVER.MARGIN, data_bunch.c) do_train( cfg,