From 3d5f7d24aa87ea5d822fbd56ecd3f034d0c93803 Mon Sep 17 00:00:00 2001 From: liaoxingyu Date: Wed, 21 Aug 2019 09:35:34 +0800 Subject: [PATCH] 1. fix minor bug 2. update experiment results in readme --- README.md | 17 ++++++--- config/defaults.py | 4 ++- configs/softmax.yml | 15 ++++---- configs/softmax_triplet.yml | 3 +- data/build.py | 2 +- layers/__init__.py | 26 +------------- layers/loss.py | 28 +++++++++++++++ modeling/backbones/resnet.py | 5 ++- scripts/train_duke.sh | 2 +- scripts/train_market.sh | 68 +++++++++++++++++++++++++++--------- tools/train.py | 4 +-- 11 files changed, 112 insertions(+), 62 deletions(-) create mode 100644 layers/loss.py diff --git a/README.md b/README.md index 4bafc2d..23c8eed 100644 --- a/README.md +++ b/README.md @@ -70,9 +70,18 @@ You can test your model's performance directly by running this command python3 tools/test.py --config_file='configs/softmax.yml' TEST.WEIGHT '/save/trained_model/path' ``` -## Results +## Experiment Results + +| size=(256, 128) batch_size=64 (16 id x 4 imgs) | | | | | | +| :------: | :-----: | :-----: | :--: | :---: | :----: | +| softmax? | ✔︎ | ✔︎ | ✔︎ | ✔︎ | ✔︎ | +| triplet? | | ✔︎ | | ✔︎ | ✔︎ | +| ibn? | | | ✔︎ | ✔︎ | ✔︎ | +| gcnet? | | | | | ✔︎ | +| Market1501 | 93.4 (82.9) | 94.2 (86.1) |93.3 (84.3)|94.9 (86.4)|-| +| DukeMTMC-reid | 84.7 (72.7) | 87.3 (76.0) |86.7 (74.9)|87.9 (77.1)|-| +| CUHK03 | | |||| -| cfg | market1501 | dukemtmc | -| --- | -- | -- | -| softmax+triplet, size=(256, 128), batch_size=64(16 id x 4 imgs) | 93.9 (85.9) | 86.5 (75.9) | + +🔥Any other tricks are welcomed! \ No newline at end of file diff --git a/config/defaults.py b/config/defaults.py index aa1fb5a..da1beee 100644 --- a/config/defaults.py +++ b/config/defaults.py @@ -63,7 +63,7 @@ _C.INPUT.P_LIGHTING=0.75 # ----------------------------------------------------------------------------- _C.DATASETS = CN() # List of the dataset names for training -_C.DATASETS.NAMES = ("cuhk03",) +_C.DATASETS.NAMES = ("market1501",) # List of the dataset names for testing _C.DATASETS.TEST_NAMES = "market1501" @@ -84,6 +84,8 @@ _C.DATALOADER.NUM_INSTANCE = 16 _C.SOLVER = CN() _C.SOLVER.OPT = "adam" +_C.SOLVER.LOSSTYPE = ("softmax",) + _C.SOLVER.MAX_EPOCHS = 50 _C.SOLVER.BASE_LR = 3e-4 diff --git a/configs/softmax.yml b/configs/softmax.yml index 54e0516..51cd785 100644 --- a/configs/softmax.yml +++ b/configs/softmax.yml @@ -1,22 +1,23 @@ MODEL: - PRETRAIN_PATH: 'home/user01/.torch/models/resnet50-19c8e357.pth' - + BACKBONE: "resnet50" INPUT: - SIZE_TRAIN: [384, 128] - SIZE_TEST: [384, 128] - PROB: 0.5 # random horizontal flip + SIZE_TRAIN: [256, 128] + SIZE_TEST: [256, 128] + FLIP_PROB: 0.5 # random horizontal flip PADDING: 10 DATASETS: - NAMES: ('market1501') + NAMES: ("market1501",) + TEST_NAMES: 'market1501' DATALOADER: SAMPLER: 'softmax' NUM_WORKERS: 8 SOLVER: - OPTIMIZER_NAME: 'Adam' + OPT: 'adam' + LOSSTYPE: ('softmax',) MAX_EPOCHS: 120 BASE_LR: 0.00035 BIAS_LR_FACTOR: 1 diff --git a/configs/softmax_triplet.yml b/configs/softmax_triplet.yml index 5f5d901..cd10fb9 100644 --- a/configs/softmax_triplet.yml +++ b/configs/softmax_triplet.yml @@ -13,11 +13,12 @@ DATASETS: TEST_NAMES: "market1501" DATALOADER: - SAMPLER: 'softmax_triplet' + SAMPLER: 'triplet' NUM_INSTANCE: 4 SOLVER: OPT: 'adam' + LOSSTYPE: ('softmax', 'triplet') MAX_EPOCHS: 150 BASE_LR: 0.00035 WEIGHT_DECAY: 0.0005 diff --git a/data/build.py b/data/build.py index 51a7d84..379ca33 100644 --- a/data/build.py +++ b/data/build.py @@ -78,7 +78,7 @@ def get_data_bunch(cfg): size=cfg.INPUT.SIZE_TRAIN, ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH, val_bs=cfg.TEST.IMS_PER_BATCH) - if 'triplet' in cfg.DATALOADER.SAMPLER: + if cfg.DATALOADER.SAMPLER == 'triplet': data_sampler = RandomIdentitySampler(train_names, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE) data_bunch.train_dl = data_bunch.train_dl.new(shuffle=False, sampler=data_sampler) diff --git a/layers/__init__.py b/layers/__init__.py index 0127fe1..80565fd 100644 --- a/layers/__init__.py +++ b/layers/__init__.py @@ -4,28 +4,4 @@ @contact: sherlockliao01@gmail.com """ -import torch.nn.functional as F - -from .triplet_loss import TripletLoss - - -def make_loss(cfg): - sampler = cfg.DATALOADER.SAMPLER - triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss - - if sampler == 'softmax': - def loss_func(out, target): - score, feat = out - return F.cross_entropy(score, target) - elif cfg.DATALOADER.SAMPLER == 'triplet': - def loss_func(out, target): - score, feat = out - return triplet(feat, target)[0] - elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': - def loss_func(out, target): - score, feat = out - return F.cross_entropy(score, target) + triplet(feat, target)[0] - else: - print('expected sampler should be softmax, triplet or softmax_triplet, ' - 'but got {}'.format(cfg.DATALOADER.SAMPLER)) - return loss_func +from .loss import reidLoss \ No newline at end of file diff --git a/layers/loss.py b/layers/loss.py new file mode 100644 index 0000000..6fd3ad4 --- /dev/null +++ b/layers/loss.py @@ -0,0 +1,28 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" +from torch import nn + +from .triplet_loss import TripletLoss + + +__all__ = ['reidLoss'] + + +class reidLoss(nn.Module): + def __init__(self, lossType:list, margin:float): + super().__init__() + self.lossType = lossType + + self.ce_loss = nn.CrossEntropyLoss() + self.triplet_loss = TripletLoss(margin) + + def forward(self, out, target): + scores, feats = out + loss = 0 + if 'softmax' in self.lossType: loss += self.ce_loss(scores, target) + if 'triplet' in self.lossType: loss += self.triplet_loss(feats, target)[0] + + return loss diff --git a/modeling/backbones/resnet.py b/modeling/backbones/resnet.py index e9e9902..d48ed01 100644 --- a/modeling/backbones/resnet.py +++ b/modeling/backbones/resnet.py @@ -33,7 +33,7 @@ __all__ = ['ResNet'] class IBN(nn.Module): def __init__(self, planes): super(IBN, self).__init__() - half1 = int(planes/2) + half1 = int(planes/8) self.half = half1 half2 = planes - half1 self.IN = nn.InstanceNorm2d(half1, affine=True) @@ -42,8 +42,7 @@ class IBN(nn.Module): def forward(self, x): split = torch.split(x, self.half, 1) out1 = self.IN(split[0].contiguous()) - # out2 = self.BN(torch.cat(split[1:], dim=1).contiguous()) - out2 = self.BN(split[1].contiguous()) + out2 = self.BN(torch.cat(split[1:], dim=1).contiguous()) out = torch.cat((out1, out2), 1) return out diff --git a/scripts/train_duke.sh b/scripts/train_duke.sh index aeca80f..89d74b1 100644 --- a/scripts/train_duke.sh +++ b/scripts/train_duke.sh @@ -1,6 +1,6 @@ gpu=0 -CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ +CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ DATASETS.NAMES '("duke",)' \ DATASETS.TEST_NAMES 'duke' \ MODEL.BACKBONE 'resnet50' \ diff --git a/scripts/train_market.sh b/scripts/train_market.sh index 61dc406..477ba1c 100644 --- a/scripts/train_market.sh +++ b/scripts/train_market.sh @@ -1,27 +1,61 @@ gpu=0 +CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ +DATASETS.NAMES '("market1501",)' \ +DATASETS.TEST_NAMES 'market1501' \ +MODEL.BACKBONE 'resnet50' \ +MODEL.IBN 'False' \ +OUTPUT_DIR 'logs/2019.8.20/market/resnet_softmax' + CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ DATASETS.NAMES '("market1501",)' \ DATASETS.TEST_NAMES 'market1501' \ MODEL.BACKBONE 'resnet50' \ MODEL.IBN 'False' \ -INPUT.DO_LIGHTING 'False' \ -SOLVER.OPT 'radam' \ -OUTPUT_DIR 'logs/2019.8.17/market/resnet_radam_nowarmup' +OUTPUT_DIR 'logs/2019.8.20/market/resnet_softmax_triplet' -# MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ +DATASETS.NAMES '("market1501",)' \ +DATASETS.TEST_NAMES 'market1501' \ +MODEL.BACKBONE 'resnet50' \ +MODEL.IBN 'True' \ +MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +OUTPUT_DIR 'logs/2019.8.20/market/resnet_ibn_softmax' -# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ -# DATASETS.NAMES '("market1501",)' \ -# DATASETS.TEST_NAMES 'market1501' \ -# MODEL.BACKBONE 'resnet50_ibn' \ -# MODEL.PRETRAIN_PATH '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ -# INPUT.DO_LIGHTING 'False' \ -# OUTPUT_DIR 'logs/2019.8.13/market/ibn7_1' +CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ +DATASETS.NAMES '("market1501",)' \ +DATASETS.TEST_NAMES 'market1501' \ +MODEL.BACKBONE 'resnet50' \ +MODEL.IBN 'True' \ +MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +OUTPUT_DIR 'logs/2019.8.20/market/resnet_ibn_softmax_triplet' -# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ -# DATASETS.NAMES '("market1501",)' \ -# DATASETS.TEST_NAMES 'market1501' \ -# SOLVER.IMS_PER_BATCH '64' \ -# INPUT.DO_LIGHTING 'True' \ -# OUTPUT_DIR 'logs/market/bs64' \ No newline at end of file +CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ +DATASETS.NAMES '("duke",)' \ +DATASETS.TEST_NAMES 'duke' \ +MODEL.BACKBONE 'resnet50' \ +MODEL.IBN 'False' \ +OUTPUT_DIR 'logs/2019.8.20/duke/resnet_softmax' + +CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ +DATASETS.NAMES '("duke",)' \ +DATASETS.TEST_NAMES 'duke' \ +MODEL.BACKBONE 'resnet50' \ +MODEL.IBN 'False' \ +OUTPUT_DIR 'logs/2019.8.20/duke/resnet_softmax_triplet' + +CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ +DATASETS.NAMES '("duke",)' \ +DATASETS.TEST_NAMES 'duke' \ +MODEL.BACKBONE 'resnet50' \ +MODEL.IBN 'True' \ +MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +OUTPUT_DIR 'logs/2019.8.20/duke/resnet_ibn_softmax' + +CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ +DATASETS.NAMES '("duke",)' \ +DATASETS.TEST_NAMES 'duke' \ +MODEL.BACKBONE 'resnet50' \ +MODEL.IBN 'True' \ +MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +OUTPUT_DIR 'logs/2019.8.20/duke/resnet_ibn_softmax_triplet' \ No newline at end of file diff --git a/tools/train.py b/tools/train.py index 16fcd68..ffbc9cc 100644 --- a/tools/train.py +++ b/tools/train.py @@ -15,7 +15,7 @@ from config import cfg from data import get_data_bunch from engine.trainer import do_train from fastai.vision import * -from layers import make_loss +from layers import reidLoss from modeling import build_model from solver import * from utils.logger import setup_logger @@ -44,7 +44,7 @@ def train(cfg): lr_sched = Scheduler(cfg.SOLVER.BASE_LR, cfg.SOLVER.MAX_EPOCHS, lr_multistep) - loss_func = make_loss(cfg) + loss_func = reidLoss(cfg.SOLVER.LOSSTYPE, cfg.SOLVER.MARGIN) do_train( cfg,