mirror of https://github.com/JDAI-CV/fast-reid.git
Finish refactor code by fastai
parent 852bb8ae8b
commit 29630d1290
config
data
datasets
samplers
transforms
engine
layers
@@ -1,6 +1,6 @@
# encoding: utf-8
"""
@author: sherlock
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

@@ -26,9 +26,9 @@ _C.MODEL.PRETRAIN_PATH = ''
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the image during training
_C.INPUT.SIZE_TRAIN = [384, 128]
_C.INPUT.SIZE_TRAIN = [256, 128]
# Size of the image during test
_C.INPUT.SIZE_TEST = [384, 128]
_C.INPUT.SIZE_TEST = [256, 128]
# Random probability for image horizontal flip
_C.INPUT.PROB = 0.5
# Values to be used for image normalization

@@ -1,5 +1,5 @@
MODEL:
  PRETRAIN_PATH: '/export/home/lxy/.torch/models/resnet50-19c8e357.pth'
  PRETRAIN_PATH: 'home/user01/.torch/models/resnet50-19c8e357.pth'

INPUT:

@@ -1,10 +1,10 @@
MODEL:
  PRETRAIN_PATH: '/export/home/lxy/.torch/models/resnet50-19c8e357.pth'
  PRETRAIN_PATH: '/home/user01/.torch/models/resnet50-19c8e357.pth'

INPUT:
  SIZE_TRAIN: [384, 128]
  SIZE_TEST: [384, 128]
  SIZE_TRAIN: [256, 128]
  SIZE_TEST: [256, 128]
  PROB: 0.5 # random horizontal flip
  PADDING: 10

@@ -14,13 +14,11 @@ DATASETS:
DATALOADER:
  SAMPLER: 'softmax_triplet'
  NUM_INSTANCE: 4
  NUM_WORKERS: 8

SOLVER:
  OPTIMIZER_NAME: 'Adam'
  MAX_EPOCHS: 120
  BASE_LR: 0.00035
  BIAS_LR_FACTOR: 1
  WEIGHT_DECAY: 0.0005
  WEIGHT_DECAY_BIAS: 0.0005
  IMS_PER_BATCH: 64

@@ -30,16 +28,13 @@ SOLVER:
  WARMUP_FACTOR: 0.01
  WARMUP_ITERS: 10
  WARMUP_METHOD: 'linear'

  CHECKPOINT_PERIOD: 40
  LOG_PERIOD: 100
  EVAL_PERIOD: 40

TEST:
  IMS_PER_BATCH: 256
  IMS_PER_BATCH: 512
  WEIGHT: "path"

OUTPUT_DIR: "/export/home/lxy/CHECKPOINTS/reid/market1501/softmax_triplet_bs128_384x128"
OUTPUT_DIR: "/home/user01/l1aoxingyu/reid_baseline/logs/co-train/softmax_triplet_bs64_256x128"

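For reference, a minimal sketch of how a YAML file like the one above is consumed, assuming the yacs-based cfg object exported by the repo's config package; the file name and the override values are illustrative only:

import sys
sys.path.append('.')
from config import cfg  # yacs CfgNode holding the defaults shown above

cfg.merge_from_file('configs/softmax_triplet.yml')   # YAML values override the defaults
cfg.merge_from_list(['SOLVER.IMS_PER_BATCH', 64])    # CLI-style overrides, as tools/train.py does
cfg.freeze()
print(cfg.INPUT.SIZE_TRAIN, cfg.SOLVER.BASE_LR)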
@@ -4,4 +4,4 @@
@contact: sherlockliao01@gmail.com
"""

from .build import make_data_loader
from .build import get_data_bunch

@@ -1,44 +1,69 @@
# encoding: utf-8
"""
@author: liaoxingyu
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

from torch.utils.data import DataLoader
import glob

from .collate_batch import train_collate_fn, val_collate_fn
from .datasets import init_dataset, ImageDataset
from fastai.vision import *
from .transforms import RandomErasing
from .samplers import RandomIdentitySampler
from .transforms import build_transforms


def make_data_loader(cfg):
    train_transforms = build_transforms(cfg, is_train=True)
    val_transforms = build_transforms(cfg, is_train=False)
    num_workers = cfg.DATALOADER.NUM_WORKERS
    if len(cfg.DATASETS.NAMES) == 1:
        dataset = init_dataset(cfg.DATASETS.NAMES)
    else:
        # TODO: add multi dataset to train
        dataset = init_dataset(cfg.DATASETS.NAMES)

    num_classes = dataset.num_train_pids
    train_set = ImageDataset(dataset.train, train_transforms)
    if cfg.DATALOADER.SAMPLER == 'softmax':
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
            num_workers=num_workers, collate_fn=train_collate_fn
        )

    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
    val_loader = DataLoader(
        val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
def get_data_bunch(cfg):
    ds_tfms = (
        [flip_lr(p=0.5),
         *rand_pad(padding=cfg.INPUT.PADDING, size=cfg.INPUT.SIZE_TRAIN, mode='zeros'),
         RandomErasing()
         ],
        None
    )
    return train_loader, val_loader, len(dataset.query), num_classes

    def _process_dir(dir_path):
        img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        v_paths = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
            v_paths.append([img_path, pid, camid])
        return v_paths

    market_train_path = 'datasets/Market-1501-v15.09.15/bounding_box_train'
    duke_train_path = 'datasets/DukeMTMC-reID/bounding_box_train'
    cuhk03_train_path = ''
    query_path = 'datasets/Market-1501-v15.09.15/query'
    gallery_path = 'datasets/Market-1501-v15.09.15/bounding_box_test'

    train_img_names = _process_dir(market_train_path) + _process_dir(duke_train_path)
    train_names = [i[0] for i in train_img_names]

    query_names = _process_dir(query_path)
    gallery_names = _process_dir(gallery_path)
    test_fnames = []
    test_labels = []
    for i in query_names + gallery_names:
        test_fnames.append(i[0])
        test_labels.append(i[1:])

    def get_labels(file_path):
        """ Suitable for muilti-dataset training """
        prefix = file_path.split('/')[1]
        pat = re.compile(r'([-\d]+)_c(\d)')
        pid, _ = pat.search(file_path).groups()
        return prefix + '_' + pid

    data_sampler = RandomIdentitySampler(train_names, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE) \
        if cfg.DATALOADER.SAMPLER == 'softmax_triplet' else None
    data_bunch = ImageDataBunch.from_name_func('datasets', train_names, label_func=get_labels, valid_pct=0,
                                               size=(256, 128), ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH,
                                               val_bs=cfg.TEST.IMS_PER_BATCH,
                                               sampler=data_sampler)
    data_bunch.add_test(test_fnames)
    data_bunch.normalize(imagenet_stats)

    return data_bunch, test_labels, len(query_names)

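A minimal usage sketch of the new loader, mirroring how tools/train.py calls it; it assumes the dataset folders under datasets/ match the hard-coded paths above:

from config import cfg
from data import get_data_bunch

# Returns a fastai ImageDataBunch, the (pid, camid) labels of query + gallery images,
# and the number of query images needed to split the test features later.
data_bunch, test_labels, num_query = get_data_bunch(cfg)
print(data_bunch.c)                 # number of training identities (classes)
print(len(test_labels), num_query)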
@@ -5,14 +5,15 @@
"""

import torch
from fastai.vision import *


def train_collate_fn(batch):
    imgs, pids, _, _, = zip(*batch)
    pids = torch.tensor(pids, dtype=torch.int64)
    return torch.stack(imgs, dim=0), pids


def val_collate_fn(batch):
    imgs, pids, camids, _ = zip(*batch)
    return torch.stack(imgs, dim=0), pids, camids
def test_collate_fn(batch):
    imgs, label = zip(*batch)
    imgs = to_data(imgs)
    pids = []
    camids = []
    for i in label:
        pids.append(i.obj[0])
        camids.append(i.obj[1])
    return torch.stack(imgs, dim=0), (torch.LongTensor(pids), torch.LongTensor(camids))

@@ -3,18 +3,33 @@
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import numpy as np
import copy
from collections import defaultdict
import sys
import warnings

try:
    from csrc.eval_cylib.eval_metrics_cy import evaluate_cy
    IS_CYTHON_AVAI = True
    print("Using Cython evaluation code as the backend")
except ImportError:
    IS_CYTHON_AVAI = False
    warnings.warn("Cython evaluation is UNAVAILABLE, which is highly recommended")


def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluation with market1501 metric
    Key: for each query identity, its gallery images from the same camera view are discarded.
    """
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
    """Evaluation with cuhk03 metric
    Key: one image for each gallery identity is randomly sampled for each query identity.
    Random sampling is performed num_repeats times.
    """
    num_repeats = 10
    num_q, num_g = distmat.shape

    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))

    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

@@ -22,6 +37,7 @@ def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query

    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]

@@ -33,13 +49,84 @@ def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
        keep = np.invert(remove)

        # compute cmc curve
        # binary vector, positions with value 1 are correct matches
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
        raw_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
        if not np.any(raw_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = orig_cmc.cumsum()
        kept_g_pids = g_pids[order][keep]
        g_pids_dict = defaultdict(list)
        for idx, pid in enumerate(kept_g_pids):
            g_pids_dict[pid].append(idx)

        cmc, AP = 0., 0.
        for repeat_idx in range(num_repeats):
            mask = np.zeros(len(raw_cmc), dtype=np.bool)
            for _, idxs in g_pids_dict.items():
                # randomly sample one image for each gallery person
                rnd_idx = np.random.choice(idxs)
                mask[rnd_idx] = True
            masked_raw_cmc = raw_cmc[mask]
            _cmc = masked_raw_cmc.cumsum()
            _cmc[_cmc > 1] = 1
            cmc += _cmc[:max_rank].astype(np.float32)
            # compute AP
            num_rel = masked_raw_cmc.sum()
            tmp_cmc = masked_raw_cmc.cumsum()
            tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
            tmp_cmc = np.asarray(tmp_cmc) * masked_raw_cmc
            AP += tmp_cmc.sum() / num_rel

        cmc /= num_repeats
        AP /= num_repeats
        all_cmc.append(cmc)
        all_AP.append(AP)
        num_valid_q += 1.

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP


def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
    """Evaluation with market1501 metric
    Key: for each query identity, its gallery images from the same camera view are discarded.
    """
    num_q, num_g = distmat.shape

    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))

    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query

    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # compute cmc curve
        raw_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
        if not np.any(raw_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = raw_cmc.cumsum()
        cmc[cmc > 1] = 1

        all_cmc.append(cmc[:max_rank])

@@ -47,10 +134,10 @@ def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = orig_cmc.sum()
        tmp_cmc = orig_cmc.cumsum()
        num_rel = raw_cmc.sum()
        tmp_cmc = raw_cmc.cumsum()
        tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
        tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
        tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
        AP = tmp_cmc.sum() / num_rel
        all_AP.append(AP)

@@ -61,3 +148,19 @@ def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    mAP = np.mean(all_AP)

    return all_cmc, mAP


def evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03):
    if use_metric_cuhk03:
        return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
    else:
        return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)


def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False, use_cython=True):
    if use_cython and IS_CYTHON_AVAI:
        return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03)
    else:
        return evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03)

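A tiny, self-contained sketch of calling the Python evaluation path (the module path is the one engine/trainer.py imports in this commit; shapes and id values are made up, and use_cython=False forces evaluate_py even when the Cython extension is built):

import numpy as np
from data.datasets.eval_reid import evaluate

q_pids   = np.array([0, 1])
q_camids = np.array([0, 0])
g_pids   = np.array([0, 1, 2])
g_camids = np.array([1, 1, 1])        # different camera, so true matches are kept
distmat  = np.random.rand(2, 3)       # (num_query, num_gallery) distances

cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids,
                    max_rank=3, use_metric_cuhk03=False, use_cython=False)
print(cmc[0], mAP)                    # Rank-1 accuracy and mean average precision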
@@ -4,11 +4,13 @@
@contact: liaoxingyu2@jd.com
"""

import copy
import random
from collections import defaultdict

import random
import copy
import numpy as np
import re
import torch
from torch.utils.data.sampler import Sampler

@@ -23,12 +25,17 @@ class RandomIdentitySampler(Sampler):
    """

    def __init__(self, data_source, batch_size, num_instances):
        pat = re.compile(r'([-\d]+)_c(\d)')

        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = self.batch_size // self.num_instances
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(self.data_source):
        for index, fname in enumerate(self.data_source):
            prefix = fname.split('/')[1]
            pid, _ = pat.search(fname).groups()
            pid = prefix + '_' + pid
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

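The sampler now derives its identity key from the file name itself rather than from an (img_path, pid, camid) tuple. A small sketch of that parsing, with an illustrative Market-1501 path:

import re

pat = re.compile(r'([-\d]+)_c(\d)')
fname = 'datasets/Market-1501-v15.09.15/bounding_box_train/0002_c1s1_000451_03.jpg'  # example path

prefix = fname.split('/')[1]            # 'Market-1501-v15.09.15' (which dataset the image came from)
pid, camid = pat.search(fname).groups()
key = prefix + '_' + pid                # 'Market-1501-v15.09.15_0002'
print(key, camid)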
@@ -71,3 +78,27 @@ class RandomIdentitySampler(Sampler):
    def __len__(self):
        return self.length

# class RandomIdentitySampler(Sampler):
#     def __init__(self, data_source, num_instances=4):
#         self.data_source = data_source
#         self.num_instances = num_instances
#         self.index_dic = defaultdict(list)
#         for index, (_, pid) in enumerate(data_source):
#             self.index_dic[pid].append(index)
#         self.pids = list(self.index_dic.keys())
#         self.num_identities = len(self.pids)
#
#     def __iter__(self):
#         indices = torch.randperm(self.num_identities)
#         ret = []
#         for i in indices:
#             pid = self.pids[i]
#             t = self.index_dic[pid]
#             replace = False if len(t) >= self.num_instances else True
#             t = np.random.choice(t, size=self.num_instances, replace=replace)
#             ret.extend(t)
#         return iter(ret)
#
#     def __len__(self):
#         return self.num_identities * self.num_instances

@@ -4,4 +4,4 @@
@contact: sherlockliao01@gmail.com
"""

from .build import build_transforms
from .transforms import *

@@ -1,31 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""

import torchvision.transforms as T

from .transforms import RandomErasing


def build_transforms(cfg, is_train=True):
    normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    if is_train:
        transform = T.Compose([
            T.Resize(cfg.INPUT.SIZE_TRAIN),
            T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
            T.Pad(cfg.INPUT.PADDING),
            T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
            T.ToTensor(),
            normalize_transform,
            RandomErasing(probability=cfg.INPUT.PROB, mean=cfg.INPUT.PIXEL_MEAN)
        ])
    else:
        transform = T.Compose([
            T.Resize(cfg.INPUT.SIZE_TEST),
            T.ToTensor(),
            normalize_transform
        ])

    return transform

@@ -4,52 +4,34 @@
@contact: liaoxingyu2@jd.com
"""

import math
import random
from fastai.vision import *
from fastai.vision.image import *


class RandomErasing(object):
    """ Randomly selects a rectangle region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf
    Args:
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area.
        mean: Erasing value.
    """
def _random_erasing(x, probability=0.5, sl=0.02, sh=0.4, r1=0.3,
                    mean=(np.array(imagenet_stats[1]) + 1) * imagenet_stats[0]):
    if random.uniform(0, 1) > probability:
        return x

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1
    for attempt in range(100):
        area = x.size()[1] * x.size()[2]

    def __call__(self, img):
        target_area = random.uniform(sl, sh) * area
        aspect_ratio = random.uniform(r1, 1 / r1)

        if random.uniform(0, 1) > self.probability:
            return img
        h = int(round(math.sqrt(target_area * aspect_ratio)))
        w = int(round(math.sqrt(target_area / aspect_ratio)))

        for attempt in range(100):
            area = img.size()[1] * img.size()[2]
        if w < x.size()[2] and h < x.size()[1]:
            x1 = random.randint(0, x.size()[1] - h)
            y1 = random.randint(0, x.size()[2] - w)
            if x.size()[0] == 3:
                x[0, x1:x1 + h, y1:y1 + w] = mean[0]
                x[1, x1:x1 + h, y1:y1 + w] = mean[1]
                x[2, x1:x1 + h, y1:y1 + w] = mean[2]
            else:
                x[0, x1:x1 + h, y1:y1 + w] = mean[0]
            return x

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                    img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                    img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
                else:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                return img

        return img
RandomErasing = TfmPixel(_random_erasing)

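With RandomErasing now wrapped as a fastai TfmPixel, it is used like any other pixel transform. A sketch of plugging it into the training transforms, matching the ds_tfms tuple built in data/build.py (the re-export through data/transforms/__init__.py shown above is assumed):

from fastai.vision import *                 # flip_lr, rand_pad, imagenet_stats, ...
from data.transforms import RandomErasing   # the TfmPixel defined above

# Train-time transforms only; the validation/test side gets None.
ds_tfms = (
    [flip_lr(p=0.5),
     *rand_pad(padding=10, size=(256, 128), mode='zeros'),
     RandomErasing()],                      # erases a random rectangle with probability 0.5
    None
)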
@@ -4,147 +4,88 @@
@contact: sherlockliao01@gmail.com
"""

import logging

import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage

from utils.reid_metric import R1_mAP
from data.datasets.eval_reid import evaluate
from fastai.vision import *


def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None):
    """
    Factory function for creating a trainer for supervised models
class LrScheduler(LearnerCallback):
    def __init__(self, learn: Learner, lr_sched: Scheduler):
        super().__init__(learn)
        self.lr_sched = lr_sched

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
    def on_train_begin(self, **kwargs: Any) -> None:
        self.opt = self.learn.opt

    Returns:
        Engine: a trainer engine with supervised update function
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        img, target = batch
        img = img.cuda()
        target = target.cuda()
        score, feat = model(img)
        loss = loss_fn(score, feat, target)
        loss.backward()
        optimizer.step()
        # compute acc
        acc = (score.max(1)[1] == target).float().mean()
        return loss.item(), acc.item()

    return Engine(_update)
    def on_epoch_begin(self, **kwargs: Any) -> None:
        self.opt.lr = self.lr_sched.step()


def create_supervised_evaluator(model, metrics,
                                device=None):
    """
    Factory function for creating an evaluator for supervised models
class TestModel(LearnerCallback):
    def __init__(self, learn: Learner, test_labels: Iterator, eval_period: int, num_query: int, output_dir: Path):
        super().__init__(learn)
        self.test_dl = learn.data.test_dl
        self.eval_period = eval_period
        self.output_dir = output_dir
        self.num_query = num_query
        pids = []
        camids = []
        for i in test_labels:
            pids.append(i[0])
            camids.append(i[1])
        self.q_pids = np.asarray(pids[:num_query])
        self.q_camids = np.asarray(camids[:num_query])
        self.g_pids = np.asarray(pids[num_query:])
        self.g_camids = np.asarray(camids[num_query:])

    Args:
        model (`torch.nn.Module`): the model to train
        metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    if device:
        model.to(device)
    def on_epoch_end(self, epoch, **kwargs: Any) -> None:
        # test model performance
        if (epoch + 1) % self.eval_period == 0:
            print('Testing ...')
            feats, pids, camids = [], [], []
            self.learn.model.eval()
            with torch.no_grad():
                for imgs, _ in self.test_dl:
                    feat = self.learn.model(imgs)
                    feats.append(feat)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            data, pids, camids = batch
            data = data.cuda()
            feat = model(data)
            return feat, pids, camids

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
            feats = torch.cat(feats, dim=0)
            # query
            qf = feats[:self.num_query]
            # gallery
            gf = feats[self.num_query:]
            m, n = qf.shape[0], gf.shape[0]
            distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
                      torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            distmat.addmm_(1, -2, qf, gf.t())
            distmat = to_np(distmat)
            cmc, mAP = evaluate(distmat, self.q_pids, self.g_pids, self.q_camids, self.g_camids)
            print("Test Results - Epoch: {}".format(epoch + 1))
            print("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                print("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
            self.learn.save(self.output_dir / 'reid_model_{}'.format(epoch))


def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        loss_fn,
        data_bunch,
        test_labels,
        opt_func,
        lr_sched,
        loss_func,
        num_query
):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query)}, device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
    timer = Timer(average=True)
    print("Start training")

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
                                                                     'optimizer': optimizer.state_dict()})
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
    learn = Learner(data_bunch, model, opt_func=opt_func, loss_func=loss_func, true_wd=False)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
    lr_sched_cb = LrScheduler(learn, lr_sched)
    testmodel_cb = TestModel(learn, test_labels, eval_period, num_query, Path(output_dir))

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, iter, len(train_loader),
                                engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
    learn.fit(epochs, callbacks=[lr_sched_cb, testmodel_cb],
              lr=cfg.SOLVER.BASE_LR, wd=cfg.SOLVER.WEIGHT_DECAY)

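A condensed sketch of how the pieces above are wired together (continuing from the earlier sketches; data_bunch, test_labels, num_query, model, opt_func, lr_sched and loss_func are assumed to already exist, as in tools/train.py):

from engine.trainer import do_train

# Inside do_train a Learner is built and fit with the two callbacks defined above:
#   learn = Learner(data_bunch, model, opt_func=opt_func, loss_func=loss_func, true_wd=False)
#   learn.fit(epochs, callbacks=[LrScheduler(learn, lr_sched), TestModel(...)], lr=..., wd=...)
do_train(cfg, model, data_bunch, test_labels, opt_func, lr_sched, loss_func, num_query)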
@@ -14,13 +14,16 @@ def make_loss(cfg):
    triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss

    if sampler == 'softmax':
        def loss_func(score, feat, target):
        def loss_func(out, target):
            score, feat = out
            return F.cross_entropy(score, target)
    elif cfg.DATALOADER.SAMPLER == 'triplet':
        def loss_func(score, feat, target):
        def loss_func(out, target):
            score, feat = out
            return triplet(feat, target)[0]
    elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
        def loss_func(score, feat, target):
        def loss_func(out, target):
            score, feat = out
            return F.cross_entropy(score, target) + triplet(feat, target)[0]
    else:
        print('expected sampler should be softmax, triplet or softmax_triplet, '

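The loss closures now receive the model output as a single (score, feat) tuple, which is what a fastai Learner passes to loss_func. A small sketch with made-up shapes (751 classes and 2048-d features are illustrative; cfg is the yacs config from earlier):

import torch
from layers import make_loss

loss_func = make_loss(cfg)                          # picks a closure based on cfg.DATALOADER.SAMPLER

score  = torch.randn(16, 751, requires_grad=True)   # classification logits (illustrative size)
feat   = torch.randn(16, 2048, requires_grad=True)  # global features used by the triplet term
target = torch.arange(4).repeat(4)                  # 4 ids x 4 instances, so every anchor has positives
loss = loss_func((score, feat), target)             # output passed as one tuple, not two tensors
loss.backward()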
@@ -4,5 +4,4 @@
@contact: sherlockliao01@gmail.com
"""

from .build import make_optimizer
from .lr_scheduler import WarmupMultiStepLR
from .build import *

@@ -4,22 +4,12 @@
@contact: sherlockliao01@gmail.com
"""

import torch
from fastai.vision import *


def make_optimizer(cfg, model):
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
def make_optimizer(cfg):
    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
        opt = partial(getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME), momentum=cfg.SOLVER.MOMENTUM)
    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    return optimizer
        opt = partial(getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME))
    return opt

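make_optimizer now hands back a constructor (a functools.partial over the torch.optim class) instead of a built optimizer, so fastai's Learner can create it with its own parameter groups. A small sketch of consuming it, assuming cfg is the yacs config shown earlier:

from solver import make_optimizer     # re-exported by solver/__init__.py via `from .build import *`

opt_func = make_optimizer(cfg)        # e.g. partial(torch.optim.Adam) when OPTIMIZER_NAME is 'Adam'

# fastai route: Learner(data_bunch, model, opt_func=opt_func, ...)
# plain PyTorch route, building it by hand:
optimizer = opt_func(model.parameters(), lr=cfg.SOLVER.BASE_LR)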
@@ -1,56 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from bisect import bisect_right
import torch


# FIXME ideally this would be achieved with a CombinedLRScheduler,
# separating MultiStepLR with WarmupLR
# but the current LRScheduler design doesn't allow it

class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of" " increasing integers. Got {}",
                milestones,
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted"
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = self.last_epoch / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr
            * warmup_factor
            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]

@@ -1,6 +1,6 @@
# encoding: utf-8
"""
@author: sherlock
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

@@ -14,10 +14,9 @@ from torch.backends import cudnn
sys.path.append('.')
from config import cfg
from data import make_data_loader
from data import get_data_bunch
from engine.inference import inference
from modeling import build_model
from utils.logger import setup_logger


def main():

@@ -54,11 +53,11 @@ def main():
    cudnn.benchmark = True

    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
    model = build_model(cfg, num_classes)
    train_databunch, test_databunch, num_query = get_data_bunch(cfg)
    model = build_model(cfg, train_databunch.c)
    model.load_state_dict(torch.load(cfg.TEST.WEIGHT))

    inference(cfg, model, val_loader, num_query)
    inference(cfg, model, test_databunch, num_query)


if __name__ == '__main__':

@@ -5,43 +5,51 @@
"""

import argparse
import os
import sys
from bisect import bisect_right

from torch.backends import cudnn

sys.path.append('.')
from config import cfg
from data import make_data_loader
from data import get_data_bunch
from engine.trainer import do_train
from modeling import build_model
from layers import make_loss
from solver import make_optimizer, WarmupMultiStepLR

from utils.logger import setup_logger
from modeling import build_model
from utils.logger import Logger
from fastai.vision import *


def train(cfg):
    # prepare dataset
    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
    # prepare model
    model = build_model(cfg, num_classes)
    data_bunch, test_labels, num_query = get_data_bunch(cfg)

    optimizer = make_optimizer(cfg, model)
    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
                                  cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
    # prepare model
    model = build_model(cfg, data_bunch.c)

    opt_func = partial(torch.optim.Adam)

    def warmup_multistep(start: float, end: float, pct: float) -> float:
        warmup_factor = 1
        gamma = cfg.SOLVER.GAMMA
        milestones = [1.0 * s / cfg.SOLVER.MAX_EPOCHS for s in cfg.SOLVER.STEPS]
        warmup_iter = 1.0 * cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_EPOCHS
        if pct < warmup_iter:
            alpha = pct / warmup_iter
            warmup_factor = cfg.SOLVER.WARMUP_FACTOR * (1 - alpha) + alpha
        return start * warmup_factor * gamma ** bisect_right(milestones, pct)

    lr_sched = Scheduler((cfg.SOLVER.BASE_LR, 0), cfg.SOLVER.MAX_EPOCHS, warmup_multistep)

    loss_func = make_loss(cfg)

    arguments = {}

    do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        data_bunch,
        test_labels,
        opt_func,
        lr_sched,
        loss_func,
        num_query
    )

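The deleted WarmupMultiStepLR is replaced by the warmup_multistep annealing function above, stepped once per epoch through fastai's Scheduler by the LrScheduler callback. A standalone numeric sketch of the same formula (GAMMA and the STEPS milestones are assumed values, since they are not shown in this diff; the other constants come from the YAML above):

from bisect import bisect_right

BASE_LR, MAX_EPOCHS, WARMUP_ITERS, WARMUP_FACTOR = 0.00035, 120, 10, 0.01
GAMMA, STEPS = 0.1, [40, 70]   # assumed for illustration

def warmup_multistep(start, end, pct):
    # pct is the fraction of training completed, i.e. epoch / MAX_EPOCHS
    warmup_factor = 1
    milestones = [s / MAX_EPOCHS for s in STEPS]
    warmup_iter = WARMUP_ITERS / MAX_EPOCHS
    if pct < warmup_iter:
        alpha = pct / warmup_iter
        warmup_factor = WARMUP_FACTOR * (1 - alpha) + alpha
    return start * warmup_factor * GAMMA ** bisect_right(milestones, pct)

for epoch in (0, 5, 10, 40, 70, 119):
    print(epoch, warmup_multistep(BASE_LR, 0, epoch / MAX_EPOCHS))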
@@ -64,20 +72,15 @@ def main():
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = setup_logger("reid_baseline", output_dir, 0)
    logger.info("Using {} GPUS".format(num_gpus))
    logger.info(args)
    sys.stdout = Logger(os.path.join(cfg.OUTPUT_DIR, 'log.txt'))
    print(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        print("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
        logger.info("Running with config:\n{}".format(cfg))
            print(config_str)
        print("Running with config:\n{}".format(cfg))

    cudnn.benchmark = True
    train(cfg)

@@ -4,27 +4,48 @@
@contact: sherlockliao01@gmail.com
"""

import logging
import errno
import os
import sys


def setup_logger(name, save_dir, distributed_rank):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)
def mkdir_if_missing(dir_path):
    try:
        os.makedirs(dir_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    if save_dir:
        fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger
class Logger(object):
    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(os.path.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        pass

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        self.console.close()
        if self.file is not None:
            self.file.close()

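A minimal sketch of the new Logger, which tees everything written to stdout into a log file (the output path is illustrative); this is exactly how tools/train.py now uses it:

import os, sys
from utils.logger import Logger

sys.stdout = Logger(os.path.join('logs/example-run', 'log.txt'))  # creates the directory if needed
print('Running with config: ...')   # goes to the console and to logs/example-run/log.txt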
@@ -1,48 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import numpy as np
import torch
from ignite.metrics import Metric

from data.datasets.eval_reid import eval_func


class R1_mAP(Metric):
    def __init__(self, num_query, max_rank=50):
        super(R1_mAP, self).__init__()
        self.num_query = num_query
        self.max_rank = max_rank

    def reset(self):
        self.feats = []
        self.pids = []
        self.camids = []

    def update(self, output):
        feat, pid, camid = output
        self.feats.append(feat)
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def compute(self):
        feats = torch.cat(self.feats, dim=0)
        # query
        qf = feats[:self.num_query]
        q_pids = np.asarray(self.pids[:self.num_query])
        q_camids = np.asarray(self.camids[:self.num_query])
        # gallery
        gf = feats[self.num_query:]
        g_pids = np.asarray(self.pids[self.num_query:])
        g_camids = np.asarray(self.camids[self.num_query:])
        m, n = qf.shape[0], gf.shape[0]
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
                  torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        distmat.addmm_(1, -2, qf, gf.t())
        distmat = distmat.cpu().numpy()
        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        return cmc, mAP

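Both the removed R1_mAP metric and the new TestModel callback build the query-gallery distance matrix from the expansion ||q - g||^2 = ||q||^2 + ||g||^2 - 2 q.g. A small sketch checking that identity against torch.cdist (feature sizes are made up; the addmm_ call uses the same legacy argument order as the code above):

import torch

qf, gf = torch.randn(4, 2048), torch.randn(6, 2048)   # assumed query/gallery feature shapes
m, n = qf.shape[0], gf.shape[0]

distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
          torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
distmat.addmm_(1, -2, qf, gf.t())                      # distmat = 1*distmat + (-2) * qf @ gf.t()

print(torch.allclose(distmat, torch.cdist(qf, gf) ** 2, rtol=1e-3))  # expected: True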