mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
Merge pull request #17 from luzai/master
add Cython version eval_market1501_wrap
This commit is contained in:
commit
f081c92741
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,6 +3,8 @@ data/
|
|||||||
log/
|
log/
|
||||||
saved-models/
|
saved-models/
|
||||||
train_imagenet.py
|
train_imagenet.py
|
||||||
|
eval_lib/eval.c
|
||||||
|
.idea
|
||||||
|
|
||||||
# OS X
|
# OS X
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
@ -1831,7 +1831,7 @@ __vid_factory = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def get_names():
|
def get_names():
|
||||||
return __img_factory.keys() + __vid_factory.keys()
|
return list(__img_factory.keys()) + list(__vid_factory.keys())
|
||||||
|
|
||||||
def init_img_dataset(name, **kwargs):
|
def init_img_dataset(name, **kwargs):
|
||||||
if name not in __img_factory.keys():
|
if name not in __img_factory.keys():
|
||||||
|
6
eval_lib/Makefile
Normal file
6
eval_lib/Makefile
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
all:
|
||||||
|
python setup.py build_ext --inplace
|
||||||
|
rm -rf build
|
||||||
|
clean:
|
||||||
|
rm -rf build
|
||||||
|
rm -f eval.c *.so
|
0
eval_lib/__init__.py
Normal file
0
eval_lib/__init__.py
Normal file
133
eval_lib/eval.pyx
Normal file
133
eval_lib/eval.pyx
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True
|
||||||
|
|
||||||
|
cimport cython
|
||||||
|
cimport numpy as np
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
cpdef eval_market1501_wrap(distmat,
|
||||||
|
q_pids,
|
||||||
|
g_pids,
|
||||||
|
q_camids,
|
||||||
|
g_camids,
|
||||||
|
max_rank):
|
||||||
|
distmat = np.asarray(distmat,dtype=np.float32)
|
||||||
|
q_pids = np.asarray(q_pids, dtype=np.int64)
|
||||||
|
g_pids = np.asarray(g_pids , dtype=np.int64)
|
||||||
|
q_camids=np.asarray(q_camids,dtype=np.int64)
|
||||||
|
g_camids=np.asarray(g_camids, dtype=np.int64)
|
||||||
|
return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||||
|
|
||||||
|
cpdef eval_market1501(
|
||||||
|
float[:,:] distmat,
|
||||||
|
long[:] q_pids,
|
||||||
|
long[:] g_pids,
|
||||||
|
long[:] q_camids,
|
||||||
|
long[:] g_camids,
|
||||||
|
long max_rank,
|
||||||
|
):
|
||||||
|
# return 0,0
|
||||||
|
cdef:
|
||||||
|
long num_q = distmat.shape[0], num_g = distmat.shape[1]
|
||||||
|
|
||||||
|
if num_g < max_rank:
|
||||||
|
max_rank = num_g
|
||||||
|
print("Note: number of gallery samples is quite small, got {}".format(num_g))
|
||||||
|
|
||||||
|
cdef:
|
||||||
|
long[:,:] indices = np.argsort(distmat, axis=1)
|
||||||
|
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
||||||
|
float[:,:] all_cmc = np.zeros((num_q,max_rank),dtype=np.float32)
|
||||||
|
float[:] all_AP = np.zeros(num_q,dtype=np.float32)
|
||||||
|
|
||||||
|
long q_pid, q_camid
|
||||||
|
long[:] order=np.zeros(num_g,dtype=np.int64), keep =np.zeros(num_g,dtype=np.int64)
|
||||||
|
|
||||||
|
long num_valid_q = 0, q_idx, idx
|
||||||
|
# long[:] orig_cmc=np.zeros(num_g,dtype=np.int64)
|
||||||
|
float[:] orig_cmc=np.zeros(num_g,dtype=np.float32)
|
||||||
|
float[:] cmc=np.zeros(num_g,dtype=np.float32), tmp_cmc=np.zeros(num_g,dtype=np.float32)
|
||||||
|
long num_orig_cmc=0
|
||||||
|
float num_rel=0.
|
||||||
|
float tmp_cmc_sum =0.
|
||||||
|
# num_orig_cmc is the valid size of orig_cmc, cmc and tmp_cmc
|
||||||
|
unsigned int orig_cmc_flag=0
|
||||||
|
|
||||||
|
for q_idx in range(num_q):
|
||||||
|
# get query pid and camid
|
||||||
|
q_pid = q_pids[q_idx]
|
||||||
|
q_camid = q_camids[q_idx]
|
||||||
|
# remove gallery samples that have the same pid and camid with query
|
||||||
|
order = indices[q_idx]
|
||||||
|
for idx in range(num_g):
|
||||||
|
keep[idx] = ( g_pids[order[idx]] !=q_pid) or (g_camids[order[idx]]!=q_camid )
|
||||||
|
# compute cmc curve
|
||||||
|
num_orig_cmc=0
|
||||||
|
orig_cmc_flag=0
|
||||||
|
for idx in range(num_g):
|
||||||
|
if keep[idx]:
|
||||||
|
orig_cmc[num_orig_cmc] = matches[q_idx][idx]
|
||||||
|
num_orig_cmc +=1
|
||||||
|
if matches[q_idx][idx]>1e-31:
|
||||||
|
orig_cmc_flag=1
|
||||||
|
if not orig_cmc_flag:
|
||||||
|
all_AP[q_idx]=-1
|
||||||
|
# print('continue ', q_idx)
|
||||||
|
# this condition is true when query identity does not appear in gallery
|
||||||
|
continue
|
||||||
|
my_cusum(orig_cmc,cmc,num_orig_cmc)
|
||||||
|
for idx in range(num_orig_cmc):
|
||||||
|
if cmc[idx] >1:
|
||||||
|
cmc[idx] =1
|
||||||
|
all_cmc[q_idx] = cmc[:max_rank]
|
||||||
|
num_valid_q+=1
|
||||||
|
|
||||||
|
# print('ori cmc', np.asarray(orig_cmc).tolist())
|
||||||
|
# print('cmc', np.asarray(cmc).tolist())
|
||||||
|
# compute average precision
|
||||||
|
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||||
|
num_rel = 0.
|
||||||
|
for idx in range(num_orig_cmc):
|
||||||
|
num_rel += orig_cmc[idx]
|
||||||
|
my_cusum( orig_cmc, tmp_cmc, num_orig_cmc)
|
||||||
|
for idx in range(num_orig_cmc):
|
||||||
|
tmp_cmc[idx] = tmp_cmc[idx] / (idx+1.) * orig_cmc[idx]
|
||||||
|
# print('tmp_cmc', np.asarray(tmp_cmc).tolist())
|
||||||
|
|
||||||
|
tmp_cmc_sum=my_sum(tmp_cmc,num_orig_cmc)
|
||||||
|
all_AP[q_idx] = tmp_cmc_sum / num_rel
|
||||||
|
# print('final',tmp_cmc_sum, num_rel, tmp_cmc_sum / num_rel,'\n')
|
||||||
|
|
||||||
|
assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
|
||||||
|
# print_dbg('all ap', all_AP)
|
||||||
|
# print_dbg('all cmc', all_cmc)
|
||||||
|
all_AP_np = np.asarray(all_AP)
|
||||||
|
all_AP_np[np.isclose(all_AP,-1)] = np.nan
|
||||||
|
return np.asarray(all_cmc).astype(np.float32).sum(axis=0) / num_valid_q, \
|
||||||
|
np.nanmean(all_AP_np)
|
||||||
|
|
||||||
|
def print_dbg(msg, val):
|
||||||
|
print(msg, np.asarray(val))
|
||||||
|
|
||||||
|
cpdef void my_cusum(
|
||||||
|
cython.numeric[:] src,
|
||||||
|
cython.numeric[:] dst,
|
||||||
|
long size
|
||||||
|
) nogil:
|
||||||
|
cdef:
|
||||||
|
long idx
|
||||||
|
for idx in range(size):
|
||||||
|
if idx==0:
|
||||||
|
dst[idx] = src[idx]
|
||||||
|
else:
|
||||||
|
dst[idx] = src[idx]+dst[idx-1]
|
||||||
|
|
||||||
|
cpdef cython.numeric my_sum(
|
||||||
|
cython.numeric[:] src,
|
||||||
|
long size
|
||||||
|
) nogil:
|
||||||
|
cdef:
|
||||||
|
long idx
|
||||||
|
cython.numeric ttl=0
|
||||||
|
for idx in range(size):
|
||||||
|
ttl+=src[idx]
|
||||||
|
return ttl
|
23
eval_lib/setup.py
Normal file
23
eval_lib/setup.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import numpy as np
|
||||||
|
from distutils.core import setup
|
||||||
|
from distutils.extension import Extension
|
||||||
|
from Cython.Distutils import build_ext
|
||||||
|
|
||||||
|
try:
|
||||||
|
numpy_include = np.get_include()
|
||||||
|
except AttributeError:
|
||||||
|
numpy_include = np.get_numpy_include()
|
||||||
|
print(numpy_include)
|
||||||
|
|
||||||
|
ext_modules = [Extension("cython_eval",
|
||||||
|
["eval.pyx"],
|
||||||
|
libraries=["m"],
|
||||||
|
include_dirs=[numpy_include],
|
||||||
|
extra_compile_args=["-ffast-math", "-Wno-cpp", "-Wno-unused-function"]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='eval_lib',
|
||||||
|
cmdclass={"build_ext": build_ext},
|
||||||
|
ext_modules=ext_modules)
|
37
eval_lib/test.py
Normal file
37
eval_lib/test.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from __future__ import absolute_import
|
||||||
|
import sys, os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.dirname(os.path.abspath(__file__)) + '/..'
|
||||||
|
)
|
||||||
|
from eval_lib.cython_eval import eval_market1501_wrap
|
||||||
|
from eval_metrics import eval_market1501
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
|
||||||
|
num_q = 300
|
||||||
|
num_g = 1500
|
||||||
|
|
||||||
|
distmat = np.random.rand(num_q, num_g) * 20
|
||||||
|
q_pids = np.random.randint(0, num_q, size=num_q)
|
||||||
|
g_pids = np.random.randint(0, num_g, size=num_g)
|
||||||
|
q_camids = np.random.randint(0, 5, size=num_q)
|
||||||
|
g_camids = np.random.randint(0, 5, size=num_g)
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
mAP, cmc = eval_market1501_wrap(distmat,
|
||||||
|
q_pids,
|
||||||
|
g_pids,
|
||||||
|
q_camids,
|
||||||
|
g_camids, 10)
|
||||||
|
toc = time.time()
|
||||||
|
print('consume time {} \n mAP is {} \n cmc is {}'.format(toc - tic, mAP, cmc))
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
cmc, mAP = eval_market1501(distmat,
|
||||||
|
q_pids,
|
||||||
|
g_pids,
|
||||||
|
q_camids,
|
||||||
|
g_camids, 10)
|
||||||
|
toc = time.time()
|
||||||
|
print('consume time {} \n mAP is {} \n cmc is {}'.format(toc - tic, mAP, cmc))
|
@ -3,6 +3,7 @@ import numpy as np
|
|||||||
import copy
|
import copy
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import sys
|
import sys
|
||||||
|
from eval_lib.cython_eval import eval_market1501_wrap
|
||||||
|
|
||||||
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100):
|
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100):
|
||||||
"""Evaluation with cuhk03 metric
|
"""Evaluation with cuhk03 metric
|
||||||
@ -126,8 +127,11 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
|||||||
|
|
||||||
return all_cmc, mAP
|
return all_cmc, mAP
|
||||||
|
|
||||||
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False):
|
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False, use_cython=False):
|
||||||
if use_metric_cuhk03:
|
if use_metric_cuhk03:
|
||||||
return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||||
else:
|
else:
|
||||||
return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
if not use_cython:
|
||||||
|
return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||||
|
else:
|
||||||
|
return eval_market1501_wrap(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||||
|
@ -1484,4 +1484,4 @@ class ResNeXt101_64x4d(nn.Module):
|
|||||||
elif self.loss == {'ring'}:
|
elif self.loss == {'ring'}:
|
||||||
return y, x
|
return y, x
|
||||||
else:
|
else:
|
||||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
raise KeyError("Unsupported loss: {}".format(self.loss))
|
||||||
|
328
train_img_model_xent_dev.py
Executable file
328
train_img_model_xent_dev.py
Executable file
@ -0,0 +1,328 @@
|
|||||||
|
from __future__ import print_function, absolute_import
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
import argparse
|
||||||
|
import os.path as osp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.optim import lr_scheduler
|
||||||
|
|
||||||
|
import data_manager
|
||||||
|
from dataset_loader import ImageDataset
|
||||||
|
import transforms as T
|
||||||
|
import models
|
||||||
|
from losses import CrossEntropyLabelSmooth, DeepSupervision
|
||||||
|
from utils import AverageMeter, Logger, save_checkpoint
|
||||||
|
from eval_metrics import evaluate
|
||||||
|
from optimizers import init_optim
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Train image model with cross entropy loss')
|
||||||
|
# Datasets
|
||||||
|
parser.add_argument('--root', type=str, help="root path to data directory", default='/home/xinglu/.torch/data/')
|
||||||
|
parser.add_argument('-d', '--dataset', type=str, default='market1501',
|
||||||
|
choices=data_manager.get_names())
|
||||||
|
parser.add_argument('-j', '--workers', default=4, type=int,
|
||||||
|
help="number of data loading workers (default: 4)")
|
||||||
|
parser.add_argument('--height', type=int, default=256,
|
||||||
|
help="height of an image (default: 256)")
|
||||||
|
parser.add_argument('--width', type=int, default=128,
|
||||||
|
help="width of an image (default: 128)")
|
||||||
|
parser.add_argument('--split-id', type=int, default=0, help="split index")
|
||||||
|
# CUHK03-specific setting
|
||||||
|
parser.add_argument('--cuhk03-labeled', action='store_true',
|
||||||
|
help="whether to use labeled images, if false, detected images are used (default: False)")
|
||||||
|
parser.add_argument('--cuhk03-classic-split', action='store_true',
|
||||||
|
help="whether to use classic split by Li et al. CVPR'14 (default: False)")
|
||||||
|
parser.add_argument('--use-metric-cuhk03', action='store_true',
|
||||||
|
help="whether to use cuhk03-metric (default: False)")
|
||||||
|
# Optimization options
|
||||||
|
parser.add_argument('--optim', type=str, default='adam', help="optimization algorithm (see optimizers.py)")
|
||||||
|
parser.add_argument('--max-epoch', default=60, type=int,
|
||||||
|
help="maximum epochs to run")
|
||||||
|
parser.add_argument('--start-epoch', default=0, type=int,
|
||||||
|
help="manual epoch number (useful on restarts)")
|
||||||
|
parser.add_argument('--train-batch', default=32, type=int,
|
||||||
|
help="train batch size")
|
||||||
|
parser.add_argument('--test-batch', default=128, type=int, help="test batch size")
|
||||||
|
parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
|
||||||
|
help="initial learning rate")
|
||||||
|
parser.add_argument('--stepsize', default=20, type=int,
|
||||||
|
help="stepsize to decay learning rate (>0 means this is enabled)")
|
||||||
|
parser.add_argument('--gamma', default=0.1, type=float,
|
||||||
|
help="learning rate decay")
|
||||||
|
parser.add_argument('--weight-decay', default=5e-04, type=float,
|
||||||
|
help="weight decay (default: 5e-04)")
|
||||||
|
# Architecture
|
||||||
|
parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())
|
||||||
|
# Miscs
|
||||||
|
parser.add_argument('--print-freq', type=int, default=10, help="print frequency")
|
||||||
|
parser.add_argument('--seed', type=int, default=1, help="manual seed")
|
||||||
|
parser.add_argument('--resume', type=str, default='./resnet50_xent_market1501.pth.tar', metavar='PATH')
|
||||||
|
parser.add_argument('--evaluate', action='store_true', help="evaluation only", default=True)
|
||||||
|
parser.add_argument('--eval-step', type=int, default=-1,
|
||||||
|
help="run evaluation for every N epochs (set to -1 to test after training)")
|
||||||
|
parser.add_argument('--save-dir', type=str, default='log')
|
||||||
|
parser.add_argument('--use-cpu', action='store_true', help="use cpu")
|
||||||
|
parser.add_argument('--gpu-devices', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
|
||||||
|
use_gpu = torch.cuda.is_available()
|
||||||
|
if args.use_cpu: use_gpu = False
|
||||||
|
|
||||||
|
if not args.evaluate:
|
||||||
|
sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
|
||||||
|
else:
|
||||||
|
sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
|
||||||
|
print("==========\nArgs:{}\n==========".format(args))
|
||||||
|
|
||||||
|
if use_gpu:
|
||||||
|
print("Currently using GPU {}".format(args.gpu_devices))
|
||||||
|
cudnn.benchmark = True
|
||||||
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
else:
|
||||||
|
print("Currently using CPU (GPU is highly recommended)")
|
||||||
|
|
||||||
|
print("Initializing dataset {}".format(args.dataset))
|
||||||
|
dataset = data_manager.init_img_dataset(
|
||||||
|
root=args.root, name=args.dataset, split_id=args.split_id,
|
||||||
|
cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
|
||||||
|
)
|
||||||
|
|
||||||
|
transform_train = T.Compose([
|
||||||
|
T.Random2DTranslation(args.height, args.width),
|
||||||
|
T.RandomHorizontalFlip(),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
|
||||||
|
transform_test = T.Compose([
|
||||||
|
T.Resize((args.height, args.width)),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
|
||||||
|
pin_memory = True if use_gpu else False
|
||||||
|
|
||||||
|
trainloader = DataLoader(
|
||||||
|
ImageDataset(dataset.train, transform=transform_train),
|
||||||
|
batch_size=args.train_batch, shuffle=True, num_workers=args.workers,
|
||||||
|
pin_memory=pin_memory, drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
queryloader = DataLoader(
|
||||||
|
ImageDataset(dataset.query, transform=transform_test),
|
||||||
|
batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
|
||||||
|
pin_memory=pin_memory, drop_last=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
galleryloader = DataLoader(
|
||||||
|
ImageDataset(dataset.gallery, transform=transform_test),
|
||||||
|
batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
|
||||||
|
pin_memory=pin_memory, drop_last=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Initializing model: {}".format(args.arch))
|
||||||
|
model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent'}, use_gpu=use_gpu)
|
||||||
|
print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters()) / 1000000.0))
|
||||||
|
|
||||||
|
criterion = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu)
|
||||||
|
optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay)
|
||||||
|
if args.stepsize > 0:
|
||||||
|
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
|
||||||
|
start_epoch = args.start_epoch
|
||||||
|
|
||||||
|
if args.resume:
|
||||||
|
print("Loading checkpoint from '{}'".format(args.resume))
|
||||||
|
# with open(args.resume, encoding='latin1') as f:
|
||||||
|
# import io
|
||||||
|
# buffer = io.BytesIO(f.read())
|
||||||
|
# checkpoint = torch.load(buffer)
|
||||||
|
checkpoint = torch.load(args.resume)
|
||||||
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
|
start_epoch = checkpoint['epoch']
|
||||||
|
|
||||||
|
if use_gpu:
|
||||||
|
model = nn.DataParallel(model).cuda()
|
||||||
|
|
||||||
|
if args.evaluate:
|
||||||
|
print("Evaluate only")
|
||||||
|
test(model, queryloader, galleryloader, use_gpu)
|
||||||
|
return
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
train_time = 0
|
||||||
|
best_rank1 = -np.inf
|
||||||
|
best_epoch = 0
|
||||||
|
print("==> Start training")
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, args.max_epoch):
|
||||||
|
start_train_time = time.time()
|
||||||
|
train(epoch, model, criterion, optimizer, trainloader, use_gpu)
|
||||||
|
train_time += round(time.time() - start_train_time)
|
||||||
|
|
||||||
|
if args.stepsize > 0: scheduler.step()
|
||||||
|
|
||||||
|
if args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
|
||||||
|
print("==> Test")
|
||||||
|
rank1 = test(model, queryloader, galleryloader, use_gpu)
|
||||||
|
is_best = rank1 > best_rank1
|
||||||
|
if is_best:
|
||||||
|
best_rank1 = rank1
|
||||||
|
best_epoch = epoch + 1
|
||||||
|
|
||||||
|
if use_gpu:
|
||||||
|
state_dict = model.module.state_dict()
|
||||||
|
else:
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
save_checkpoint({
|
||||||
|
'state_dict': state_dict,
|
||||||
|
'rank1': rank1,
|
||||||
|
'epoch': epoch,
|
||||||
|
}, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
|
||||||
|
|
||||||
|
print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))
|
||||||
|
|
||||||
|
elapsed = round(time.time() - start_time)
|
||||||
|
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||||
|
train_time = str(datetime.timedelta(seconds=train_time))
|
||||||
|
print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
|
||||||
|
|
||||||
|
|
||||||
|
def train(epoch, model, criterion, optimizer, trainloader, use_gpu):
|
||||||
|
losses = AverageMeter()
|
||||||
|
batch_time = AverageMeter()
|
||||||
|
data_time = AverageMeter()
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
|
||||||
|
# measure data loading time
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
|
||||||
|
if use_gpu:
|
||||||
|
imgs, pids = imgs.cuda(), pids.cuda()
|
||||||
|
outputs = model(imgs)
|
||||||
|
if isinstance(outputs, tuple):
|
||||||
|
loss = DeepSupervision(criterion, outputs, pids)
|
||||||
|
else:
|
||||||
|
loss = criterion(outputs, pids)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# measure elapsed time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
losses.update(loss.item(), pids.size(0))
|
||||||
|
|
||||||
|
if (batch_idx + 1) % args.print_freq == 0:
|
||||||
|
print('Epoch: [{0}][{1}/{2}]\t'
|
||||||
|
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
||||||
|
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
|
||||||
|
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
|
||||||
|
epoch + 1, batch_idx + 1, len(trainloader), batch_time=batch_time,
|
||||||
|
data_time=data_time, loss=losses))
|
||||||
|
|
||||||
|
|
||||||
|
def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]):
|
||||||
|
batch_time = AverageMeter()
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
qf, q_pids, q_camids = [], [], []
|
||||||
|
for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
|
||||||
|
if use_gpu: imgs = imgs.cuda()
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
features = model(imgs)
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
|
||||||
|
features = features.data.cpu()
|
||||||
|
qf.append(features)
|
||||||
|
q_pids.extend(pids)
|
||||||
|
q_camids.extend(camids)
|
||||||
|
qf = torch.cat(qf, 0)
|
||||||
|
q_pids = np.asarray(q_pids)
|
||||||
|
q_camids = np.asarray(q_camids)
|
||||||
|
|
||||||
|
print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
|
||||||
|
|
||||||
|
gf, g_pids, g_camids = [], [], []
|
||||||
|
end = time.time()
|
||||||
|
for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
|
||||||
|
if use_gpu: imgs = imgs.cuda()
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
features = model(imgs)
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
|
||||||
|
features = features.data.cpu()
|
||||||
|
gf.append(features)
|
||||||
|
g_pids.extend(pids)
|
||||||
|
g_camids.extend(camids)
|
||||||
|
gf = torch.cat(gf, 0)
|
||||||
|
g_pids = np.asarray(g_pids)
|
||||||
|
g_camids = np.asarray(g_camids)
|
||||||
|
|
||||||
|
print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))
|
||||||
|
|
||||||
|
print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, args.test_batch))
|
||||||
|
|
||||||
|
m, n = qf.size(0), gf.size(0)
|
||||||
|
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
|
||||||
|
torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
|
||||||
|
distmat.addmm_(1, -2, qf, gf.t())
|
||||||
|
distmat = distmat.numpy()
|
||||||
|
# qf = qf.numpy()
|
||||||
|
# gf = gf.numpy()
|
||||||
|
# from scipy.spatial.distance import cdist
|
||||||
|
# distmat = cdist(qf, gf)
|
||||||
|
|
||||||
|
print("Computing CMC and mAP, origin eval")
|
||||||
|
tic = time.time()
|
||||||
|
cmc0, mAP0 = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, use_metric_cuhk03=args.use_metric_cuhk03,
|
||||||
|
use_cython=False)
|
||||||
|
print("Results ----------")
|
||||||
|
print("mAP: {:.1%}".format(mAP0))
|
||||||
|
print("CMC curve")
|
||||||
|
for r in ranks:
|
||||||
|
print("Rank-{:<3}: {:.1%}".format(r, cmc0[r - 1]))
|
||||||
|
print("consume {} s".format(time.time() - tic))
|
||||||
|
print("------------------\n")
|
||||||
|
|
||||||
|
print("Computing CMC and mAP, use cython")
|
||||||
|
tic = time.time()
|
||||||
|
cmc1, mAP1 = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, use_metric_cuhk03=args.use_metric_cuhk03,
|
||||||
|
use_cython=True)
|
||||||
|
print("Results ----------")
|
||||||
|
print("mAP: {:.1%}".format(mAP1))
|
||||||
|
print("CMC curve")
|
||||||
|
for r in ranks:
|
||||||
|
print("Rank-{:<3}: {:.1%}".format(r, cmc1[r - 1]))
|
||||||
|
print("consume {} s".format(time.time() - tic))
|
||||||
|
print("------------------\n")
|
||||||
|
|
||||||
|
print('absolute difference between two version mAP is {}, '
|
||||||
|
'relative difference between two version mAP is {} %'.format(
|
||||||
|
np.abs(mAP0 - mAP1), np.abs(mAP0 - mAP1) * 100. / mAP0
|
||||||
|
))
|
||||||
|
|
||||||
|
return cmc1[0]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user