Add cuhk03
pull/25/head
liaoxingyu 2019-08-01 15:49:31 +08:00
parent 369325906f
commit 48518552cf
8 changed files with 36 additions and 21 deletions

View File

@ -43,7 +43,7 @@ _C.INPUT.PADDING = 10
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training, as present in paths_catalog.py
_C.DATASETS.NAMES = ()
_C.DATASETS.NAMES = ("cuhk03",)
# -----------------------------------------------------------------------------
# DataLoader

View File

@ -11,6 +11,7 @@ import re
from fastai.vision import *
from .transforms import RandomErasing
from .samplers import RandomIdentitySampler
from .datasets import CUHK03
def get_data_bunch(cfg):
@ -39,7 +40,6 @@ def get_data_bunch(cfg):
v_paths.append([img_path,pid,camid])
return v_paths
market_train_path = 'datasets/Market-1501-v15.09.15/bounding_box_train'
duke_train_path = 'datasets/DukeMTMC-reID/bounding_box_train'
cuhk03_train_path = 'datasets/cuhk03/'
@ -58,12 +58,11 @@ def get_data_bunch(cfg):
train_img_names.extend(_process_dir(duke_train_path))
elif d == 'beijing':
train_img_names.extend(_process_dir(bjStation_train_path, True))
elif d == 'cuhk03':
train_img_names.extend(CUHK03().train)
else:
raise NameError("{} is not available".format(d))
# train_img_names = _process_dir(market_train_path) + _process_dir(duke_train_path) + _process_dir(bjStation_train_path)
# train_img_names = _process_dir(market_train_path)
# train_img_names = CUHK03().train
train_names = [i[0] for i in train_img_names]
query_names = _process_dir(bj_query_path)
@ -77,17 +76,17 @@ def get_data_bunch(cfg):
def get_labels(file_path):
""" Suitable for muilti-dataset training """
# if 'cuhk03' in file_path:
# prefix = 'cuhk'
# pid = file_path.split('/')[-1].split('_')[1]
# else:
prefix = file_path.split('/')[1]
pat = re.compile(r'([-\d]+)_c(\d)')
pid, _ = pat.search(file_path).groups()
if 'cuhk03' in file_path:
prefix = 'cuhk'
pid = '_'.join(file_path.split('/')[-1].split('_')[0:2])
else:
prefix = file_path.split('/')[1]
pat = re.compile(r'([-\d]+)_c(\d)')
pid, _ = pat.search(file_path).groups()
return prefix + '_' + pid
data_bunch = ImageDataBunch.from_name_func('datasets', train_names, label_func=get_labels, valid_pct=0,
size=(256, 128), ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH,
size=cfg.INPUT.SIZE_TRAIN, ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH,
val_bs=cfg.TEST.IMS_PER_BATCH)
if 'triplet' in cfg.DATALOADER.SAMPLER:

View File

@ -34,7 +34,11 @@ class RandomIdentitySampler(Sampler):
self.index_dic = defaultdict(list)
for index, fname in enumerate(self.data_source):
prefix = fname.split('/')[1]
pid, _ = pat.search(fname).groups()
try:
pid, _ = pat.search(fname).groups()
except:
prefix = fname.split('/')[4]
pid = '_'.join(fname.split('/')[-1].split('_')[:2])
pid = prefix + '_' + pid
self.index_dic[pid].append(index)
self.pids = list(self.index_dic.keys())

View File

@ -1,6 +1,7 @@
gpu=3
gpu=2
CUDA_VISIBLE_DEVICES=$gpu python tools/test.py -cfg='configs/softmax_triplet.yml' \
INPUT.SIZE_TRAIN '(384, 128)' \
DATASETS.NAMES '("market1501","duke","beijing")' \
OUTPUT_DIR 'logs/test' \
TEST.WEIGHT 'logs/beijing/market+duke+bj/models/model_149.pth'

View File

@ -0,0 +1,5 @@
gpu=3
CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("market1501","beijing",)' \
OUTPUT_DIR 'logs/beijing/market_bj'

View File

@ -0,0 +1,5 @@
gpu=2
CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("market1501","duke","cuhk03","beijing")' \
OUTPUT_DIR 'logs/beijing/market_duke_cuhk03_beijing_256x128_bs512'

View File

@ -50,10 +50,11 @@ def main():
cudnn.benchmark = True
data_bunch, test_labels, num_query = get_data_bunch(cfg)
model = build_model(cfg, data_bunch.c)
state_dict = torch.load(cfg.TEST.WEIGHT)
model.load_state_dict(state_dict['model'])
model.cuda()
# model = build_model(cfg, data_bunch.c)
# state_dict = torch.load(cfg.TEST.WEIGHT)
# model.load_state_dict(state_dict['model'])
# model.cuda()
model = torch.jit.load("/export/home/lxy/reid_baseline/pcb_model_v0.2.pt")
inference(cfg, model, data_bunch, test_labels, num_query)

View File

@ -27,8 +27,8 @@ def train(cfg):
# prepare model
model = build_model(cfg, data_bunch.c)
state_dict = torch.load("logs/beijing/market_duke_softmax_triplet_256_128_bs512/models/model_149.pth")
model.load_params_wo_fc(state_dict['model'])
# state_dict = torch.load("logs/beijing/market_duke_softmax_triplet_256_128_bs512/models/model_149.pth")
# model.load_params_wo_fc(state_dict['model'])
opt_func = partial(torch.optim.Adam)