From 9f1fdbdda8a4078f8ae0be2550d2985498dc5775 Mon Sep 17 00:00:00 2001
From: liaoxingyu
Date: Fri, 8 Jun 2018 12:59:03 +0800
Subject: [PATCH] first update

---
 README.md                       |  16 +-
 bases/__init__.py               |  11 +
 bases/base_evaluator.py         |  75 ++++++
 bases/base_trainer.py           |  70 ++++++
 config.py                       |  59 +++++
 datasets/data_loader.py         |  34 +++
 datasets/data_manager.py        | 109 +++++++++
 datasets/samplers.py            |  32 +++
 main_reid.py                    | 410 ++++++++++++++++++++++++++++++++
 models/__init__.py              |  12 +
 models/resnet.py                | 121 ++++++++++
 models/resnet_reid.py           |  67 ++++++
 scripts/train_classification.sh |   3 +
 trainers/__init__.py            |  13 +
 trainers/evaluator.py           |  84 +++++++
 trainers/trainer.py             |  54 +++++
 utils/__init__.py               |  11 +
 utils/loss.py                   | 154 ++++++++++++
 utils/meters.py                 |  50 ++++
 utils/serialization.py          |  74 ++++++
 utils/transforms.py             |  80 +++++++
 utils/validation_metrics.py     |  25 ++
 22 files changed, 1563 insertions(+), 1 deletion(-)
 create mode 100644 bases/__init__.py
 create mode 100644 bases/base_evaluator.py
 create mode 100644 bases/base_trainer.py
 create mode 100644 config.py
 create mode 100755 datasets/data_loader.py
 create mode 100755 datasets/data_manager.py
 create mode 100755 datasets/samplers.py
 create mode 100644 main_reid.py
 create mode 100644 models/__init__.py
 create mode 100644 models/resnet.py
 create mode 100644 models/resnet_reid.py
 create mode 100644 scripts/train_classification.sh
 create mode 100644 trainers/__init__.py
 create mode 100644 trainers/evaluator.py
 create mode 100644 trainers/trainer.py
 create mode 100644 utils/__init__.py
 create mode 100644 utils/loss.py
 create mode 100644 utils/meters.py
 create mode 100644 utils/serialization.py
 create mode 100644 utils/transforms.py
 create mode 100644 utils/validation_metrics.py

diff --git a/README.md b/README.md
index 24f82cc..753e60c 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,16 @@
 # reid_baseline
-reid baseline model for exploring softmax and triplet hard loss
+reid baseline model for exploring the influence of softmax and triplet hard losses.
+
+## Configuration
+
+### Classification
+resnet lr: 0.1
+classifier lr: 0.01
+
+### Triplet Hard
+lr: 2e-4
+
+### Classification + Triplet Hard
+lr: 2e-4
+
+exponential decay at epoch 150
\ No newline at end of file
diff --git a/bases/__init__.py b/bases/__init__.py
new file mode 100644
index 0000000..4ec04de
--- /dev/null
+++ b/bases/__init__.py
@@ -0,0 +1,11 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
diff --git a/bases/base_evaluator.py b/bases/base_evaluator.py
new file mode 100644
index 0000000..4ac183a
--- /dev/null
+++ b/bases/base_evaluator.py
@@ -0,0 +1,75 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import torch
+
+
+class BaseEvaluator(object):
+    def __init__(self, model):
+        self.model = model
+
+    def evaluate(self, queryloader, galleryloader, ranks=[1, 5, 10, 20]):
+        self.model.eval()
+        qf, q_pids, q_camids = [], [], []
+        for batch_idx, inputs in enumerate(queryloader):
+            inputs, pids, camids = self._parse_data(inputs)
+            feature = self._forward(inputs)
+            qf.append(feature)
+            q_pids.extend(pids)
+            q_camids.extend(camids)
+        qf = torch.cat(qf, 0)
+        q_pids = np.asarray(q_pids)
+        q_camids = np.asarray(q_camids)
+
+        print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
+
+        gf, g_pids, g_camids = [], [], []
+        for batch_idx, inputs in enumerate(galleryloader):
+            inputs, pids, camids = self._parse_data(inputs)
+            feature = self._forward(inputs)
+            gf.append(feature)
+            g_pids.extend(pids)
+            g_camids.extend(camids)
+        gf = torch.cat(gf, 0)
+        g_pids = np.asarray(g_pids)
+        g_camids = np.asarray(g_camids)
+
+        print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))
+
+        print("Computing distance matrix")
+
+        m, n = qf.size(0), gf.size(0)
+        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
+            torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
+        distmat.addmm_(1, -2, qf, gf.t())
+        distmat = distmat.numpy()
+
+        print("Computing CMC and mAP")
+        cmc, mAP = self.eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
+
+        print("Results ----------")
+        print("mAP: {:.1%}".format(mAP))
+        print("CMC curve")
+        for r in ranks:
+            print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
+        print("------------------")
+
+        return cmc[0]
+
+    def _parse_data(self, inputs):
+        raise NotImplementedError
+
+    def _forward(self, inputs):
+        raise NotImplementedError
+
+    def eval_func(self, distmat, q_pids, g_pids, q_camids, g_camids):
+        raise NotImplementedError
diff --git a/bases/base_trainer.py b/bases/base_trainer.py
new file mode 100644
index 0000000..acaab41
--- /dev/null
+++ b/bases/base_trainer.py
@@ -0,0 +1,70 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import time
+
+from utils.meters import AverageMeter
+
+
+class BaseTrainer(object):
+    def __init__(self, model, criterion, tb_writer):
+        self.model = model
+        self.criterion = criterion
+        self.tb_writer = tb_writer
+
+    def train(self, epoch, data_loader, optimizer,
+              print_freq=1):
+        self.model.train()
+
+        batch_time = AverageMeter()
+        data_time = AverageMeter()
+        losses = AverageMeter()
+
+        start = time.time()
+        for i, inputs in enumerate(data_loader):
+            data_time.update(time.time() - start)
+
+            # model optimizer
+            inputs, targets = self._parse_data(inputs)
+            loss = self._forward(inputs, targets)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            batch_time.update(time.time() - start)
+            losses.update(loss.item())
+
+            # tensorboard
+            global_step = epoch * len(data_loader) + i
+            self.tb_writer.add_scalar('loss', loss.item(), global_step)
+            self.tb_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step)
+
+            start = time.time()
+
+            if (i + 1) % print_freq == 0:
+                print('Epoch: [{}][{}/{}]\t'
+                      'Batch Time {:.3f} ({:.3f})\t'
+                      'Data Time {:.3f} ({:.3f})\t'
+                      'Loss {:.3f} ({:.3f})\t'
+                      .format(epoch, i + 1, len(data_loader),
+                              batch_time.val, batch_time.mean,
+                              data_time.val, data_time.mean,
+                              losses.val, losses.mean))
+        param_group = optimizer.param_groups
+        print('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3e}\t'
+              'Lr {:.2e}'
+              .format(epoch, batch_time.sum, losses.mean, param_group[0]['lr']))
+        print()
+
+    def _parse_data(self, inputs):
+        raise NotImplementedError
+
+    def _forward(self, inputs, targets):
+        raise NotImplementedError
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..8882d2c
--- /dev/null
+++ b/config.py
@@ -0,0 +1,59 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import warnings
+
+
+class DefaultConfig(object):
+    seed = 0
+
+    # dataset options
+    dataset = 'market'
+    height = 384
+    width = 192
+
+    # optimization options
+    optim = 'SGD'
+    max_epoch = 100
+    train_batch = 128
+    test_batch = 128
+    lr = 0.1
+    step_size = 60
+    gamma = 0.1
+    weight_decay = 5e-4
+    momentum = 0.9
+    margin = 0.3
+    num_instances = 4
+
+    # model options
+    model_name = 'ResNetBuilder'
+    last_stride = 1
+
+    # miscs
+    print_freq = 30
+    eval_step = 50
+    save_dir = '/DATA/pytorch-ckpt/market'
+    gpu = 0, 1
+    workers = 10
+    start_epoch = 0
+
+    def _parse(self, kwargs):
+        for k, v in kwargs.items():
+            if not hasattr(self, k):
+                warnings.warn("Warning: opt has no attribute %s" % k)
+            setattr(self, k, v)
+
+    def _state_dict(self):
+        return {k: getattr(self, k) for k, _ in DefaultConfig.__dict__.items()
+                if not k.startswith('_')}
+
+
+args = DefaultConfig()
diff --git a/datasets/data_loader.py b/datasets/data_loader.py
new file mode 100755
index 0000000..69888f2
--- /dev/null
+++ b/datasets/data_loader.py
@@ -0,0 +1,34 @@
+from __future__ import print_function, absolute_import
+
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+def read_image(img_path):
+    """Keep reading the image until it succeeds.
+    This avoids IOErrors raised during heavy IO load."""
+    got_img = False
+    while not got_img:
+        try:
+            img = Image.open(img_path).convert('RGB')
+            got_img = True
+        except IOError:
+            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
+            pass
+    return img
+
+
+class ImageData(Dataset):
+    def __init__(self, dataset, transform):
+        self.dataset = dataset
+        self.transform = transform
+
+    def __getitem__(self, item):
+        img, pid, camid = self.dataset[item]
+        img = read_image(img)
+        if self.transform is not None:
+            img = self.transform(img)
+        return img, pid, camid
+
+    def __len__(self):
+        return len(self.dataset)
diff --git a/datasets/data_manager.py b/datasets/data_manager.py
new file mode 100755
index 0000000..e694279
--- /dev/null
+++ b/datasets/data_manager.py
@@ -0,0 +1,109 @@
+from __future__ import print_function, absolute_import
+
+import glob
+import re
+from os import path as osp
+
+"""Dataset classes"""
+
+
+class Market1501(object):
+    """
+    Market1501
+    Reference:
+    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
+    URL: http://www.liangzheng.org/Project/project_reid.html
+
+    Dataset statistics:
+    # identities: 1501 (+1 for background)
+    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
+    """
+    dataset_dir = 'market1501'
+
+    def __init__(self, root='/home/liaoxingyu/', **kwargs):
+        self.dataset_dir = osp.join(root, self.dataset_dir)
+        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
+        self.query_dir = osp.join(self.dataset_dir, 'query')
+        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
+
+        self._check_before_run()
+
+        train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
+        query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
+        gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)
+        num_total_pids = num_train_pids + num_query_pids
+        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
+
+        print("=> Market1501 loaded")
+        print("Dataset statistics:")
+        print("  ------------------------------")
+        print("  subset   | # ids | # images")
+        print("  ------------------------------")
+        print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
+        print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
+        print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
+        print("  ------------------------------")
+        print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
+        print("  ------------------------------")
+
+        self.train = train
+        self.query = query
+        self.gallery = gallery
+
+        self.num_train_pids = num_train_pids
+        self.num_query_pids = num_query_pids
+        self.num_gallery_pids = num_gallery_pids
+
+    def _check_before_run(self):
+        """Check if all files are available before going deeper"""
+        if not osp.exists(self.dataset_dir):
+            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
+        if not osp.exists(self.train_dir):
+            raise RuntimeError("'{}' is not available".format(self.train_dir))
+        if not osp.exists(self.query_dir):
+            raise RuntimeError("'{}' is not available".format(self.query_dir))
+        if not osp.exists(self.gallery_dir):
+            raise RuntimeError("'{}' is not available".format(self.gallery_dir))
+
+    def _process_dir(self, dir_path, relabel=False):
+        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
+        pattern = re.compile(r'([-\d]+)_c(\d)')
+
+        pid_container = set()
+        for img_path in img_paths:
+            pid, _ = map(int, pattern.search(img_path).groups())
+            if pid == -1: continue  # junk images are just ignored
+            pid_container.add(pid)
+        pid2label = {pid: label for label, pid in enumerate(pid_container)}
+
+        dataset = []
+        for img_path in img_paths:
+            pid, camid = map(int, pattern.search(img_path).groups())
+            if pid == -1:
+                continue  # junk images are just ignored
+            assert 0 <= pid <= 1501  # pid == 0 means background
+            assert 1 <= camid <= 6
+            camid -= 1  # index starts from 0
+            if relabel: pid = pid2label[pid]
+            dataset.append((img_path, pid, camid))
+
+        num_pids = len(pid_container)
+        num_imgs = len(dataset)
+        return dataset, num_pids, num_imgs
+
+
+"""Create datasets"""
+
+__factory = {
+    'market': Market1501
+}
+
+
+def get_names():
+    return __factory.keys()
+
+
+def init_dataset(name, *args, **kwargs):
+    if name not in __factory.keys():
+        raise KeyError("Unknown datasets: {}".format(name))
+    return __factory[name](*args, **kwargs)
diff --git a/datasets/samplers.py b/datasets/samplers.py
new file mode 100755
index 0000000..590e8fa
--- /dev/null
+++ b/datasets/samplers.py
@@ -0,0 +1,32 @@
+from __future__ import absolute_import
+
+from collections import defaultdict
+
+import numpy as np
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class RandomIdentitySampler(Sampler):
+    def __init__(self, data_source, num_instances=4):
+        self.data_source = data_source
+        self.num_instances = num_instances
+        self.index_dic = defaultdict(list)
+        for index, (_, pid, _) in enumerate(data_source):
+            self.index_dic[pid].append(index)
+        self.pids = list(self.index_dic.keys())
+        self.num_identities = len(self.pids)
+
+    def __iter__(self):
+        indices = torch.randperm(self.num_identities)
+        ret = []
+        for i in indices:
+            pid = self.pids[i]
+            t = self.index_dic[pid]
+            replace = False if len(t) >= self.num_instances else True
+            t = np.random.choice(t, size=self.num_instances, replace=replace)
+            ret.extend(t)
+        return iter(ret)
+
+    def __len__(self):
+        return self.num_identities * self.num_instances
diff --git a/main_reid.py b/main_reid.py
new file mode 100644
index 0000000..d9bed12
--- /dev/null
+++ b/main_reid.py
@@ -0,0 +1,410 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+import sys
+from os import path as osp
+from pprint import pprint
+
+import numpy as np
+import torch
+from datasets.samplers import RandomIdentitySampler
+from tensorboardX import SummaryWriter
+from torch import nn
+from torch.backends import cudnn
+from torch.utils.data import DataLoader
+
+from config import args
+from datasets import data_manager
+from datasets.data_loader import ImageData
+from models import ResNetBuilder
+from trainers import ResNetClsTrainer, ResNetTriTrainer, ResNetClsTriTrainer, ResNetEvaluator
+from utils.loss import TripletLoss
+from utils.serialization import Logger
+from utils.serialization import save_checkpoint
+from utils.transforms import TrainTransform, TestTransform
+
+
+def train_classification(**kwargs):
+    args._parse(kwargs)
+
+    # set random seed and cudnn benchmark
+    torch.manual_seed(args.seed)
+
+    use_gpu = torch.cuda.is_available()
+    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
+
+    print('=========user config==========')
+    pprint(args._state_dict())
+    print('============end===============')
+
+    if use_gpu:
+        print('currently using GPU {}'.format(args.gpu))
+        cudnn.benchmark = True
+        torch.cuda.manual_seed_all(args.seed)
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+    else:
+        print('currently using cpu')
+
+    print('initializing dataset {}'.format(args.dataset))
+    dataset = data_manager.init_dataset(name=args.dataset)
+
+    pin_memory = True if use_gpu else False
+
+    tb_writer = SummaryWriter(osp.join(args.save_dir, 'tb_log'))
+
+    trainloader = DataLoader(
+        ImageData(dataset.train, TrainTransform(args.height, args.width)),
+        batch_size=args.train_batch, shuffle=True, num_workers=args.workers,
+        pin_memory=pin_memory, drop_last=True
+    )
+
+    queryloader = DataLoader(
+        ImageData(dataset.query, TestTransform(args.height, args.width)),
+        batch_size=args.test_batch, num_workers=args.workers,
+        pin_memory=pin_memory
+    )
+
+    galleryloader = DataLoader(
+        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
+        batch_size=args.test_batch, num_workers=args.workers,
+        pin_memory=pin_memory
+    )
+
+    print('initializing model ...')
+    model = ResNetBuilder(num_classes=dataset.num_train_pids)
+    print('model size: {:.5f}M'.format(sum(p.numel()
+                                           for p in model.parameters()) / 1e6))
+
+    cls_criterion = nn.CrossEntropyLoss()
+
+    def xent_criterion(cls_scores, targets):
+        cls_loss = cls_criterion(cls_scores, targets)
+        return cls_loss
+
+    # get optimizer
+    optimizer = torch.optim.SGD(
+        model.optim_policy(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum
+    )
+
+    def adjust_lr(optimizer, ep, decay_ep, gamma):
+        decay = gamma ** float(ep // decay_ep)
+        for g in optimizer.param_groups:
+            g['lr'] = args.lr * decay * g.get('lr_multi', 1)
+
+    start_epoch = args.start_epoch
+    if use_gpu:
+        model = nn.DataParallel(model).cuda()
+
+    # get trainer and evaluator
+    reid_trainer = ResNetClsTrainer(model, xent_criterion, tb_writer)
+    reid_evaluator = ResNetEvaluator(model)
+
+    # start training
+    best_rank1 = -np.inf
+    best_epoch = 0
+    for epoch in range(start_epoch, args.max_epoch):
+        if args.step_size > 0:
+            adjust_lr(optimizer, epoch + 1, args.step_size, args.gamma)
+        reid_trainer.train(epoch, trainloader, optimizer, args.print_freq)
+
+        # evaluate and save only every eval_step epochs (and at the end)
+        if args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
+            rank1 = reid_evaluator.evaluate(queryloader, galleryloader)
+            is_best = rank1 > best_rank1
+            if is_best:
+                best_rank1 = rank1
+                best_epoch = epoch + 1
+
+            if use_gpu:
+                state_dict = model.module.state_dict()
+            else:
+                state_dict = model.state_dict()
+            save_checkpoint({
+                'state_dict': state_dict,
+                'epoch': epoch + 1,
+            }, is_best=is_best, save_dir=args.save_dir, filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
+
+    print(
+        'Best rank-1 {:.1%}, achieved at epoch {}'.format(best_rank1, best_epoch))
+
+
+def train_triplet(**kwargs):
+    args._parse(kwargs)
+
+    # set random seed and cudnn benchmark
+    torch.manual_seed(args.seed)
+
+    use_gpu = torch.cuda.is_available()
+    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
+
+    print('=========user config==========')
+    pprint(args._state_dict())
+    print('============end===============')
+
+    if use_gpu:
+        print('currently using GPU {}'.format(args.gpu))
+        cudnn.benchmark = True
+        torch.cuda.manual_seed_all(args.seed)
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+    else:
+        print('currently using cpu')
+
+    print('initializing dataset {}'.format(args.dataset))
+    dataset = data_manager.init_dataset(name=args.dataset)
+
+    pin_memory = True if use_gpu else False
+
+    tb_writer = SummaryWriter(osp.join(args.save_dir, 'tb_log'))
+
+    trainloader = DataLoader(
+        ImageData(dataset.train, TrainTransform(args.height, args.width)),
+        sampler=RandomIdentitySampler(dataset.train, args.num_instances),
+        batch_size=args.train_batch, num_workers=args.workers,
+        pin_memory=pin_memory, drop_last=True
+    )
+
+    queryloader = DataLoader(
+        ImageData(dataset.query, TestTransform(args.height, args.width)),
+        batch_size=args.test_batch, num_workers=args.workers,
+        pin_memory=pin_memory
+    )
+
+    galleryloader = DataLoader(
+        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
+        batch_size=args.test_batch, num_workers=args.workers,
+        pin_memory=pin_memory
+    )
+
+    print('initializing model ...')
+    model = ResNetBuilder()
+    print('model size: {:.5f}M'.format(sum(p.numel()
+                                           for p in model.parameters()) / 1e6))
+
+    tri_criterion = TripletLoss(margin=args.margin)
+
+    def tri_hard(feat, targets):
+        tri_loss, _, _ = tri_criterion(feat, targets)
+        return tri_loss
+
+    # get optimizer
+    optimizer = torch.optim.Adam(
+        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
+    )
+
+    def adjust_lr_exp(optimizer, base_lr, ep, total_ep, start_decay_ep, gamma):
+        if ep < start_decay_ep:
+            return
+        lr_decay = gamma ** (float(ep - start_decay_ep) /
+                             (total_ep - start_decay_ep))
+        for g in optimizer.param_groups:
+            g['lr'] = base_lr * lr_decay
+
+    start_epoch = args.start_epoch
+    if use_gpu:
+        model = nn.DataParallel(model).cuda()
+
+    # get trainer and evaluator
+    reid_trainer = ResNetTriTrainer(model, tri_hard, tb_writer)
+    reid_evaluator = ResNetEvaluator(model)
+
+    # start training
+    best_rank1 = -np.inf
+    best_epoch = 0
+    for epoch in range(start_epoch, args.max_epoch):
+        if args.step_size > 0:
+            adjust_lr_exp(optimizer, args.lr, epoch + 1, args.max_epoch, args.step_size, args.gamma)
+        reid_trainer.train(epoch, trainloader, optimizer, args.print_freq)
+
+        # evaluate and save only every eval_step epochs (and at the end)
+        if args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
+            rank1 = reid_evaluator.evaluate(queryloader, galleryloader)
+            is_best = rank1 > best_rank1
+            if is_best:
+                best_rank1 = rank1
+                best_epoch = epoch + 1
+
+            if use_gpu:
+                state_dict = model.module.state_dict()
+            else:
+                state_dict = model.state_dict()
+            save_checkpoint({
+                'state_dict': state_dict,
+                'epoch': epoch + 1,
+            }, is_best=is_best, save_dir=args.save_dir, filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
+
+    print(
+        'Best rank-1 {:.1%}, achieved at epoch {}'.format(best_rank1, best_epoch))
+
+
+def train_cls_triplet(**kwargs):
+    args._parse(kwargs)
+
+    # set random seed and cudnn benchmark
+    torch.manual_seed(args.seed)
+
+    use_gpu = torch.cuda.is_available()
+    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
+
+    print('=========user config==========')
+    pprint(args._state_dict())
+    print('============end===============')
+
+    if use_gpu:
+        print('currently using GPU {}'.format(args.gpu))
+        cudnn.benchmark = True
+        torch.cuda.manual_seed_all(args.seed)
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+    else:
+        print('currently using cpu')
+
+    print('initializing dataset {}'.format(args.dataset))
+    dataset = data_manager.init_dataset(name=args.dataset)
+
+    pin_memory = True if use_gpu else False
+
+    tb_writer = SummaryWriter(osp.join(args.save_dir, 'tb_log'))
+
+    trainloader = DataLoader(
+        ImageData(dataset.train, TrainTransform(args.height, args.width)),
+        sampler=RandomIdentitySampler(dataset.train, args.num_instances),
+        batch_size=args.train_batch, num_workers=args.workers,
+        pin_memory=pin_memory, drop_last=True
+    )
+
+    queryloader = DataLoader(
+        ImageData(dataset.query, TestTransform(args.height, args.width)),
+        batch_size=args.test_batch, num_workers=args.workers,
+        pin_memory=pin_memory
+    )
+
+    galleryloader = DataLoader(
+        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
+        batch_size=args.test_batch, num_workers=args.workers,
+        pin_memory=pin_memory
+    )
+
+    print('initializing model ...')
+    model = ResNetBuilder(num_classes=dataset.num_train_pids)
+    print('model size: {:.5f}M'.format(sum(p.numel()
+                                           for p in model.parameters()) / 1e6))
+
+    cls_criterion = nn.CrossEntropyLoss()
+    tri_criterion = TripletLoss(margin=args.margin)
+
+    def xent_tri_criterion(cls_scores, global_feat, targets):
+        cls_loss = cls_criterion(cls_scores, targets)
+        tri_loss, dist_ap, dist_an = tri_criterion(global_feat, targets)
+        loss = cls_loss + tri_loss
+        return loss
+
+    # get optimizer
+    optimizer = torch.optim.Adam(
+        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
+    )
+
+    def adjust_lr_exp(optimizer, base_lr, ep, total_ep, start_decay_ep, gamma):
+        if ep < start_decay_ep:
+            return
+        lr_decay = gamma ** (float(ep - start_decay_ep) /
+                             (total_ep - start_decay_ep))
+        for g in optimizer.param_groups:
+            g['lr'] = base_lr * lr_decay
+
+    start_epoch = args.start_epoch
+    if use_gpu:
+        model = nn.DataParallel(model).cuda()
+
+    # get trainer and evaluator
+    reid_trainer = ResNetClsTriTrainer(model, xent_tri_criterion, tb_writer)
+    reid_evaluator = ResNetEvaluator(model)
+
+    # start training
+    best_rank1 = -np.inf
+    best_epoch = 0
+    for epoch in range(start_epoch, args.max_epoch):
+        if args.step_size > 0:
+            adjust_lr_exp(optimizer, args.lr, epoch + 1, args.max_epoch, args.step_size, args.gamma)
+        reid_trainer.train(epoch, trainloader, optimizer, args.print_freq)
+
+        # evaluate and save only every eval_step epochs (and at the end)
+        if args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
+            rank1 = reid_evaluator.evaluate(queryloader, galleryloader)
+            is_best = rank1 > best_rank1
+            if is_best:
+                best_rank1 = rank1
+                best_epoch = epoch + 1
+
+            if use_gpu:
+                state_dict = model.module.state_dict()
+            else:
+                state_dict = model.state_dict()
+            save_checkpoint({
+                'state_dict': state_dict,
+                'epoch': epoch + 1,
+            }, is_best=is_best, save_dir=args.save_dir, filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
+
+    print(
+        'Best rank-1 {:.1%}, achieved at epoch {}'.format(best_rank1, best_epoch))
+
+
+def test(**kwargs):
+    args._parse(kwargs)
+
+    # set random seed and cudnn benchmark
+    torch.manual_seed(args.seed)
+
+    use_gpu = torch.cuda.is_available()
+    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
+
+    if use_gpu:
+        print('currently using GPU {}'.format(args.gpu))
+        cudnn.benchmark = True
+        torch.cuda.manual_seed_all(args.seed)
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+    else:
+        print('currently using cpu')
+
+    print('initializing dataset {}'.format(args.dataset))
+    dataset = data_manager.init_dataset(name=args.dataset)
+
+    pin_memory = True if use_gpu else False
+
+    queryloader = DataLoader(
+        ImageData(dataset.query, TestTransform(args.height, args.width)),
+        batch_size=args.test_batch, num_workers=args.workers,
+        pin_memory=pin_memory
+    )
+
+    galleryloader = DataLoader(
+        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
+        batch_size=args.test_batch, num_workers=args.workers,
+        pin_memory=pin_memory
+    )
+
+    print('loading model ...')
+    model = ResNetBuilder(num_classes=dataset.num_train_pids)
+    # ckpt = torch.load(args.load_model)
+    # model.load_state_dict(ckpt['state_dict'])
+    print('model size: {:.5f}M'.format(sum(p.numel()
+                                           for p in model.parameters()) / 1e6))
+
+    if use_gpu:
+        model = nn.DataParallel(model).cuda()
+
+    reid_evaluator = ResNetEvaluator(model)
+    reid_evaluator.evaluate(queryloader, galleryloader)
+
+
+if __name__ == '__main__':
+    import fire
+
+    fire.Fire()
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..68e22d9
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,12 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from .resnet_reid import ResNetBuilder
diff --git a/models/resnet.py b/models/resnet.py
new file mode 100644
index 0000000..f63fa5f
--- /dev/null
+++ b/models/resnet.py
@@ -0,0 +1,121 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: liaoxingyu@megvii.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import math
+
+import torch
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+    def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]):
+        self.inplanes = 64
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(
+            block, 512, layers[3], stride=last_stride)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.maxpool(nn.functional.relu(x))  # add missed ReLU between bn1 and maxpool, as in the standard ResNet stem
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        return x
+
+    def load_param(self, model_path):
+        param_dict = torch.load(model_path)
+        for i in param_dict:
+            if 'fc' in i:
+                continue
+            self.state_dict()[i].copy_(param_dict[i])
+
+    def random_init(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+if __name__ == "__main__":
+    net = ResNet(last_stride=2)
+    import torch
+    x = net(torch.zeros(1, 3, 256, 128))
+    print(x.shape)
diff --git a/models/resnet_reid.py b/models/resnet_reid.py
new file mode 100644
index 0000000..4b6f58f
--- /dev/null
+++ b/models/resnet_reid.py
@@ -0,0 +1,67 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import torch.nn.functional as F
+from torch import nn
+
+from .resnet import ResNet
+
+
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        nn.init.normal_(m.weight, std=0.001)
+        nn.init.constant_(m.bias, 0.0)
+    elif classname.find('Conv') != -1:
+        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+        if m.bias is not None:  # conv layers may be created with bias=False
+            nn.init.constant_(m.bias, 0.0)
+    elif classname.find('BatchNorm') != -1:
+        if m.affine:
+            nn.init.normal_(m.weight, 1.0, 0.02)
+            nn.init.constant_(m.bias, 0.0)
+
+
+class ResNetBuilder(nn.Module):
+    in_planes = 2048
+
+    def __init__(self, num_classes=None, last_stride=1, model_path='/DATA/model_zoo/resnet50-19c8e357.pth'):
+        super().__init__()
+        self.base = ResNet(last_stride)
+        self.base.load_param(model_path)
+        self.bottleneck = nn.Sequential(
+            nn.BatchNorm2d(self.in_planes),
+            nn.ReLU(True)
+        )
+        self.num_classes = num_classes
+        if num_classes is not None:
+            self.classifier = nn.Linear(self.in_planes, num_classes)
+
+    def forward(self, x):
+        feat = self.base(x)
+        feat = self.bottleneck(feat)
+        global_feat = F.avg_pool2d(feat, feat.shape[2:])  # (b, 2048, 1, 1)
+        global_feat = global_feat.view(global_feat.shape[0], -1)
+        if self.training and self.num_classes is not None:
+            cls_score = self.classifier(global_feat)
+            return cls_score, global_feat
+        else:
+            return global_feat
+
+    def optim_policy(self):
+        base_param_group = self.base.parameters()
+        clf_param_group = self.classifier.parameters()
+        return [
+            {'params': base_param_group, 'lr_multi': 0.1},
+            {'params': clf_param_group}
+        ]
+
+
diff --git a/scripts/train_classification.sh b/scripts/train_classification.sh
new file mode 100644
index 0000000..ec7467d
--- /dev/null
+++ b/scripts/train_classification.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env bash
+
+python3 ../main_reid.py train_classification --save_dir='/DATA/pytorch-ckpt/market1501'
diff --git a/trainers/__init__.py b/trainers/__init__.py
new file mode 100644
index 0000000..4af45ee
--- /dev/null
+++ b/trainers/__init__.py
@@ -0,0 +1,13 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from .evaluator import ResNetEvaluator
+from .trainer import ResNetClsTrainer, ResNetTriTrainer, ResNetClsTriTrainer
diff --git a/trainers/evaluator.py b/trainers/evaluator.py
new file mode 100644
index 0000000..e8f6c63
--- /dev/null
+++ b/trainers/evaluator.py
@@ -0,0 +1,84 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import torch
+
+from bases.base_evaluator import BaseEvaluator
+
+
+class ResNetEvaluator(BaseEvaluator):
+    def __init__(self, model):
+        super().__init__(model)
+
+    def _parse_data(self, inputs):
+        imgs, pids, camids = inputs
+        return imgs.cuda(), pids, camids
+
+    def _forward(self, inputs):
+        with torch.no_grad():
+            feature = self.model(inputs)
+        return feature.cpu()
+
+    def eval_func(self, distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
+        """Evaluation with market1501 metric
+        Key: for each query identity, its gallery images from the same camera view are discarded.
+        """
+        num_q, num_g = distmat.shape
+        if num_g < max_rank:
+            max_rank = num_g
+            print("Note: number of gallery samples is quite small, got {}".format(num_g))
+        indices = np.argsort(distmat, axis=1)
+        matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
+
+        # compute cmc curve for each query
+        all_cmc = []
+        all_AP = []
+        num_valid_q = 0.  # number of valid query
+        for q_idx in range(num_q):
+            # get query pid and camid
+            q_pid = q_pids[q_idx]
+            q_camid = q_camids[q_idx]
+
+            # remove gallery samples that have the same pid and camid with query
+            order = indices[q_idx]
+            remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
+            keep = np.invert(remove)
+
+            # compute cmc curve
+            # binary vector, positions with value 1 are correct matches
+            orig_cmc = matches[q_idx][keep]
+            if not np.any(orig_cmc):
+                # this condition is true when query identity does not appear in gallery
+                continue
+
+            cmc = orig_cmc.cumsum()
+            cmc[cmc > 1] = 1
+
+            all_cmc.append(cmc[:max_rank])
+            num_valid_q += 1.
+
+            # compute average precision
+            # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
+            num_rel = orig_cmc.sum()
+            tmp_cmc = orig_cmc.cumsum()
+            tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
+            tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
+            AP = tmp_cmc.sum() / num_rel
+            all_AP.append(AP)
+
+        assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
+
+        all_cmc = np.asarray(all_cmc).astype(np.float32)
+        all_cmc = all_cmc.sum(0) / num_valid_q
+        mAP = np.mean(all_AP)
+
+        return all_cmc, mAP
diff --git a/trainers/trainer.py b/trainers/trainer.py
new file mode 100644
index 0000000..d1b68e0
--- /dev/null
+++ b/trainers/trainer.py
@@ -0,0 +1,54 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from bases.base_trainer import BaseTrainer
+
+
+class ResNetClsTrainer(BaseTrainer):
+    def __init__(self, model, criterion, tb_writer):
+        super().__init__(model, criterion, tb_writer)
+
+    def _parse_data(self, inputs):
+        imgs, pids, _ = inputs
+        return imgs.cuda(), pids.cuda()
+
+    def _forward(self, inputs, targets):
+        cls_score, _ = self.model(inputs)
+        loss = self.criterion(cls_score, targets)
+        return loss
+
+
+class ResNetTriTrainer(BaseTrainer):
+    def __init__(self, model, criterion, tb_writer):
+        super().__init__(model, criterion, tb_writer)
+
+    def _parse_data(self, inputs):
+        imgs, pids, _ = inputs
+        return imgs.cuda(), pids.cuda()
+
+    def _forward(self, inputs, targets):
+        feat = self.model(inputs)
+        loss = self.criterion(feat, targets)
+        return loss
+
+
+class ResNetClsTriTrainer(BaseTrainer):
+    def __init__(self, model, criterion, tb_writer):
+        super().__init__(model, criterion, tb_writer)
+
+    def _parse_data(self, inputs):
+        imgs, pids, _ = inputs
+        return imgs.cuda(), pids.cuda()
+
+    def _forward(self, inputs, targets):
+        cls_score, feat = self.model(inputs)
+        loss = self.criterion(cls_score, feat, targets)
+        return loss
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..4ec04de
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,11 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
diff --git a/utils/loss.py b/utils/loss.py
new file mode 100644
index 0000000..33d2c5e
--- /dev/null
+++ b/utils/loss.py
@@ -0,0 +1,154 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+def normalize(x, axis=-1):
+    """Normalizing to unit length along the specified dimension.
+    Args:
+      x: pytorch Variable
+    Returns:
+      x: pytorch Variable, same shape as input
+    """
+    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
+    return x
+
+
+def euclidean_dist(x, y):
+    """
+    Args:
+      x: pytorch Variable, with shape [m, d]
+      y: pytorch Variable, with shape [n, d]
+    Returns:
+      dist: pytorch Variable, with shape [m, n]
+    """
+    m, n = x.size(0), y.size(0)
+    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
+    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
+    dist = xx + yy
+    dist.addmm_(1, -2, x, y.t())
+    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
+    return dist
+
+
+def hard_example_mining(dist_mat, labels, return_inds=False):
+    """For each anchor, find the hardest positive and negative sample.
+    Args:
+      dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
+      labels: pytorch LongTensor, with shape [N]
+      return_inds: whether to return the indices. Save time if `False`(?)
+    Returns:
+      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
+      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
+      p_inds: pytorch LongTensor, with shape [N];
+        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
+      n_inds: pytorch LongTensor, with shape [N];
+        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
+    NOTE: Only consider the case in which all labels have same num of samples,
+      thus we can cope with all anchors in parallel.
+    """
+
+    assert len(dist_mat.size()) == 2
+    assert dist_mat.size(0) == dist_mat.size(1)
+    N = dist_mat.size(0)
+
+    # shape [N, N]
+    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
+    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
+
+    # `dist_ap` means distance(anchor, positive)
+    # both `dist_ap` and `relative_p_inds` with shape [N, 1]
+    dist_ap, relative_p_inds = torch.max(
+        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
+    # `dist_an` means distance(anchor, negative)
+    # both `dist_an` and `relative_n_inds` with shape [N, 1]
+    dist_an, relative_n_inds = torch.min(
+        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
+    # shape [N]
+    dist_ap = dist_ap.squeeze(1)
+    dist_an = dist_an.squeeze(1)
+
+    if return_inds:
+        # shape [N, N]
+        ind = (labels.new().resize_as_(labels)
+               .copy_(torch.arange(0, N).long())
+               .unsqueeze(0).expand(N, N))
+        # shape [N, 1]
+        p_inds = torch.gather(
+            ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
+        n_inds = torch.gather(
+            ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
+        # shape [N]
+        p_inds = p_inds.squeeze(1)
+        n_inds = n_inds.squeeze(1)
+        return dist_ap, dist_an, p_inds, n_inds
+
+    return dist_ap, dist_an
+
+
+class TripletLoss(object):
+    """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
+    Related triplet loss theory can be found in the paper 'In Defense of the Triplet
+    Loss for Person Re-Identification'."""
+
+    def __init__(self, margin=None):
+        self.margin = margin
+        if margin is not None:
+            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
+        else:
+            self.ranking_loss = nn.SoftMarginLoss()
+
+    def __call__(self, global_feat, labels, normalize_feature=False):
+        if normalize_feature:
+            global_feat = normalize(global_feat, axis=-1)
+        dist_mat = euclidean_dist(global_feat, global_feat)
+        dist_ap, dist_an = hard_example_mining(
+            dist_mat, labels)
+        y = dist_an.new().resize_as_(dist_an).fill_(1)
+        if self.margin is not None:
+            loss = self.ranking_loss(dist_an, dist_ap, y)
+        else:
+            loss = self.ranking_loss(dist_an - dist_ap, y)
+        return loss, dist_ap, dist_an
+
+
+class CrossEntropyLabelSmooth(nn.Module):
+    """Cross entropy loss with label smoothing regularizer.
+    Reference:
+    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
+    Equation: y = (1 - epsilon) * y + epsilon / K.
+    Args:
+        num_classes (int): number of classes.
+        epsilon (float): weight.
+    """
+
+    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
+        super(CrossEntropyLabelSmooth, self).__init__()
+        self.num_classes = num_classes
+        self.epsilon = epsilon
+        self.use_gpu = use_gpu
+        self.logsoftmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, inputs, targets):
+        """
+        Args:
+            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
+            targets: ground truth labels with shape (batch_size)
+        """
+        log_probs = self.logsoftmax(inputs)
+        targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
+        if self.use_gpu: targets = targets.cuda()
+        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
+        loss = (- targets * log_probs).mean(0).sum()
+        return loss
diff --git a/utils/meters.py b/utils/meters.py
new file mode 100644
index 0000000..d8bec7b
--- /dev/null
+++ b/utils/meters.py
@@ -0,0 +1,50 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import math
+
+import numpy as np
+
+
+class AverageMeter(object):
+    def __init__(self):
+        self.n = 0
+        self.sum = 0.0
+        self.var = 0.0
+        self.val = 0.0
+        self.mean = np.nan
+        self.std = np.nan
+
+    def update(self, value, n=1):
+        self.val = value
+        self.sum += value
+        self.var += value * value
+        self.n += n
+
+        if self.n == 0:
+            self.mean, self.std = np.nan, np.nan
+        elif self.n == 1:
+            self.mean, self.std = self.sum, np.inf
+        else:
+            self.mean = self.sum / self.n
+            self.std = math.sqrt(
+                (self.var - self.n * self.mean * self.mean) / (self.n - 1.0))
+
+    def value(self):
+        return self.mean, self.std
+
+    def reset(self):
+        self.n = 0
+        self.sum = 0.0
+        self.var = 0.0
+        self.val = 0.0
+        self.mean = np.nan
+        self.std = np.nan
diff --git a/utils/serialization.py b/utils/serialization.py
new file mode 100644
index 0000000..7f55c75
--- /dev/null
+++ b/utils/serialization.py
@@ -0,0 +1,74 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: xyliao1993@qq.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import errno
+import os
+import shutil
+import sys
+
+import os.path as osp
+import torch
+
+
+class Logger(object):
+    """
+    Write console output to external text file.
+    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
+    """
+
+    def __init__(self, fpath=None):
+        self.console = sys.stdout
+        self.file = None
+        if fpath is not None:
+            mkdir_if_missing(os.path.dirname(fpath))
+            self.file = open(fpath, 'w')
+
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, *args):
+        self.close()
+
+    def write(self, msg):
+        self.console.write(msg)
+        if self.file is not None:
+            self.file.write(msg)
+
+    def flush(self):
+        self.console.flush()
+        if self.file is not None:
+            self.file.flush()
+            os.fsync(self.file.fileno())
+
+    def close(self):
+        self.console.close()
+        if self.file is not None:
+            self.file.close()
+
+
+def mkdir_if_missing(dir_path):
+    try:
+        os.makedirs(dir_path)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+
+
+def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth.tar'):
+    fpath = '_'.join((str(state['epoch']), filename))
+    fpath = osp.join(save_dir, fpath)
+    mkdir_if_missing(save_dir)
+    torch.save(state, fpath)
+    if is_best:
+        shutil.copy(fpath, osp.join(save_dir, 'model_best.pth.tar'))
diff --git a/utils/transforms.py b/utils/transforms.py
new file mode 100644
index 0000000..e5ff683
--- /dev/null
+++ b/utils/transforms.py
@@ -0,0 +1,80 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import random
+
+from PIL import Image
+from torchvision import transforms as T
+
+
+class Random2DTranslation(object):
+    """
+    With a probability, first increase image size to (1 + 1/8), and then perform random crop.
+
+    Args:
+        height (int): target height.
+        width (int): target width.
+        p (float): probability of performing this transformation. Default: 0.5.
+    """
+
+    def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
+        self.height = height
+        self.width = width
+        self.p = p
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be cropped.
+
+        Returns:
+            PIL Image: Cropped image.
+ """ + if random.random() < self.p: + return img.resize((self.width, self.height), self.interpolation) + new_width, new_height = int( + round(self.width * 1.125)), int(round(self.height * 1.125)) + resized_img = img.resize((new_width, new_height), self.interpolation) + x_maxrange = new_width - self.width + y_maxrange = new_height - self.height + x1 = int(round(random.uniform(0, x_maxrange))) + y1 = int(round(random.uniform(0, y_maxrange))) + croped_img = resized_img.crop( + (x1, y1, x1 + self.width, y1 + self.height)) + return croped_img + + +class TrainTransform(object): + def __init__(self, h, w): + self.h = h + self.w = w + + def __call__(self, x): + x = Random2DTranslation(self.h, self.w)(x) + x = T.RandomHorizontalFlip()(x) + x = T.ToTensor()(x) + x = T.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])(x) + return x + + +class TestTransform(object): + def __init__(self, h, w): + self.h = h + self.w = w + + def __call__(self, x=None): + x = T.Resize((self.h, self.w))(x) + x = T.ToTensor()(x) + x = T.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])(x) + return x diff --git a/utils/validation_metrics.py b/utils/validation_metrics.py new file mode 100644 index 0000000..73c726b --- /dev/null +++ b/utils/validation_metrics.py @@ -0,0 +1,25 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: xyliao1993@qq.com +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +def accuracy(score, target, topk=(1,)): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = score.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + ret = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) + ret.append(correct_k.mul_(1. / batch_size)) + return ret