mirror of https://github.com/JDAI-CV/fast-reid.git
parent
369325906f
commit
48518552cf
|
@ -43,7 +43,7 @@ _C.INPUT.PADDING = 10
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
_C.DATASETS = CN()
|
_C.DATASETS = CN()
|
||||||
# List of the dataset names for training, as present in paths_catalog.py
|
# List of the dataset names for training, as present in paths_catalog.py
|
||||||
_C.DATASETS.NAMES = ()
|
_C.DATASETS.NAMES = ("cuhk03",)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# DataLoader
|
# DataLoader
|
||||||
|
|
|
@ -11,6 +11,7 @@ import re
|
||||||
from fastai.vision import *
|
from fastai.vision import *
|
||||||
from .transforms import RandomErasing
|
from .transforms import RandomErasing
|
||||||
from .samplers import RandomIdentitySampler
|
from .samplers import RandomIdentitySampler
|
||||||
|
from .datasets import CUHK03
|
||||||
|
|
||||||
|
|
||||||
def get_data_bunch(cfg):
|
def get_data_bunch(cfg):
|
||||||
|
@ -39,7 +40,6 @@ def get_data_bunch(cfg):
|
||||||
v_paths.append([img_path,pid,camid])
|
v_paths.append([img_path,pid,camid])
|
||||||
return v_paths
|
return v_paths
|
||||||
|
|
||||||
|
|
||||||
market_train_path = 'datasets/Market-1501-v15.09.15/bounding_box_train'
|
market_train_path = 'datasets/Market-1501-v15.09.15/bounding_box_train'
|
||||||
duke_train_path = 'datasets/DukeMTMC-reID/bounding_box_train'
|
duke_train_path = 'datasets/DukeMTMC-reID/bounding_box_train'
|
||||||
cuhk03_train_path = 'datasets/cuhk03/'
|
cuhk03_train_path = 'datasets/cuhk03/'
|
||||||
|
@ -58,12 +58,11 @@ def get_data_bunch(cfg):
|
||||||
train_img_names.extend(_process_dir(duke_train_path))
|
train_img_names.extend(_process_dir(duke_train_path))
|
||||||
elif d == 'beijing':
|
elif d == 'beijing':
|
||||||
train_img_names.extend(_process_dir(bjStation_train_path, True))
|
train_img_names.extend(_process_dir(bjStation_train_path, True))
|
||||||
|
elif d == 'cuhk03':
|
||||||
|
train_img_names.extend(CUHK03().train)
|
||||||
else:
|
else:
|
||||||
raise NameError("{} is not available".format(d))
|
raise NameError("{} is not available".format(d))
|
||||||
|
|
||||||
# train_img_names = _process_dir(market_train_path) + _process_dir(duke_train_path) + _process_dir(bjStation_train_path)
|
|
||||||
# train_img_names = _process_dir(market_train_path)
|
|
||||||
# train_img_names = CUHK03().train
|
|
||||||
train_names = [i[0] for i in train_img_names]
|
train_names = [i[0] for i in train_img_names]
|
||||||
|
|
||||||
query_names = _process_dir(bj_query_path)
|
query_names = _process_dir(bj_query_path)
|
||||||
|
@ -77,17 +76,17 @@ def get_data_bunch(cfg):
|
||||||
|
|
||||||
def get_labels(file_path):
|
def get_labels(file_path):
|
||||||
""" Suitable for muilti-dataset training """
|
""" Suitable for muilti-dataset training """
|
||||||
# if 'cuhk03' in file_path:
|
if 'cuhk03' in file_path:
|
||||||
# prefix = 'cuhk'
|
prefix = 'cuhk'
|
||||||
# pid = file_path.split('/')[-1].split('_')[1]
|
pid = '_'.join(file_path.split('/')[-1].split('_')[0:2])
|
||||||
# else:
|
else:
|
||||||
prefix = file_path.split('/')[1]
|
prefix = file_path.split('/')[1]
|
||||||
pat = re.compile(r'([-\d]+)_c(\d)')
|
pat = re.compile(r'([-\d]+)_c(\d)')
|
||||||
pid, _ = pat.search(file_path).groups()
|
pid, _ = pat.search(file_path).groups()
|
||||||
return prefix + '_' + pid
|
return prefix + '_' + pid
|
||||||
|
|
||||||
data_bunch = ImageDataBunch.from_name_func('datasets', train_names, label_func=get_labels, valid_pct=0,
|
data_bunch = ImageDataBunch.from_name_func('datasets', train_names, label_func=get_labels, valid_pct=0,
|
||||||
size=(256, 128), ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH,
|
size=cfg.INPUT.SIZE_TRAIN, ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH,
|
||||||
val_bs=cfg.TEST.IMS_PER_BATCH)
|
val_bs=cfg.TEST.IMS_PER_BATCH)
|
||||||
|
|
||||||
if 'triplet' in cfg.DATALOADER.SAMPLER:
|
if 'triplet' in cfg.DATALOADER.SAMPLER:
|
||||||
|
|
|
@ -34,7 +34,11 @@ class RandomIdentitySampler(Sampler):
|
||||||
self.index_dic = defaultdict(list)
|
self.index_dic = defaultdict(list)
|
||||||
for index, fname in enumerate(self.data_source):
|
for index, fname in enumerate(self.data_source):
|
||||||
prefix = fname.split('/')[1]
|
prefix = fname.split('/')[1]
|
||||||
|
try:
|
||||||
pid, _ = pat.search(fname).groups()
|
pid, _ = pat.search(fname).groups()
|
||||||
|
except:
|
||||||
|
prefix = fname.split('/')[4]
|
||||||
|
pid = '_'.join(fname.split('/')[-1].split('_')[:2])
|
||||||
pid = prefix + '_' + pid
|
pid = prefix + '_' + pid
|
||||||
self.index_dic[pid].append(index)
|
self.index_dic[pid].append(index)
|
||||||
self.pids = list(self.index_dic.keys())
|
self.pids = list(self.index_dic.keys())
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
gpu=3
|
gpu=2
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=$gpu python tools/test.py -cfg='configs/softmax_triplet.yml' \
|
CUDA_VISIBLE_DEVICES=$gpu python tools/test.py -cfg='configs/softmax_triplet.yml' \
|
||||||
|
INPUT.SIZE_TRAIN '(384, 128)' \
|
||||||
DATASETS.NAMES '("market1501","duke","beijing")' \
|
DATASETS.NAMES '("market1501","duke","beijing")' \
|
||||||
OUTPUT_DIR 'logs/test' \
|
OUTPUT_DIR 'logs/test' \
|
||||||
TEST.WEIGHT 'logs/beijing/market+duke+bj/models/model_149.pth'
|
TEST.WEIGHT 'logs/beijing/market+duke+bj/models/model_149.pth'
|
|
@ -0,0 +1,5 @@
|
||||||
|
gpu=3
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
|
||||||
|
DATASETS.NAMES '("market1501","beijing",)' \
|
||||||
|
OUTPUT_DIR 'logs/beijing/market_bj'
|
|
@ -0,0 +1,5 @@
|
||||||
|
gpu=2
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
|
||||||
|
DATASETS.NAMES '("market1501","duke","cuhk03","beijing")' \
|
||||||
|
OUTPUT_DIR 'logs/beijing/market_duke_cuhk03_beijing_256x128_bs512'
|
|
@ -50,10 +50,11 @@ def main():
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
|
|
||||||
data_bunch, test_labels, num_query = get_data_bunch(cfg)
|
data_bunch, test_labels, num_query = get_data_bunch(cfg)
|
||||||
model = build_model(cfg, data_bunch.c)
|
# model = build_model(cfg, data_bunch.c)
|
||||||
state_dict = torch.load(cfg.TEST.WEIGHT)
|
# state_dict = torch.load(cfg.TEST.WEIGHT)
|
||||||
model.load_state_dict(state_dict['model'])
|
# model.load_state_dict(state_dict['model'])
|
||||||
model.cuda()
|
# model.cuda()
|
||||||
|
model = torch.jit.load("/export/home/lxy/reid_baseline/pcb_model_v0.2.pt")
|
||||||
|
|
||||||
inference(cfg, model, data_bunch, test_labels, num_query)
|
inference(cfg, model, data_bunch, test_labels, num_query)
|
||||||
|
|
||||||
|
|
|
@ -27,8 +27,8 @@ def train(cfg):
|
||||||
|
|
||||||
# prepare model
|
# prepare model
|
||||||
model = build_model(cfg, data_bunch.c)
|
model = build_model(cfg, data_bunch.c)
|
||||||
state_dict = torch.load("logs/beijing/market_duke_softmax_triplet_256_128_bs512/models/model_149.pth")
|
# state_dict = torch.load("logs/beijing/market_duke_softmax_triplet_256_128_bs512/models/model_149.pth")
|
||||||
model.load_params_wo_fc(state_dict['model'])
|
# model.load_params_wo_fc(state_dict['model'])
|
||||||
|
|
||||||
opt_func = partial(torch.optim.Adam)
|
opt_func = partial(torch.optim.Adam)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue