mirror of https://github.com/JDAI-CV/fast-reid.git
parent 369325906f
commit 48518552cf

@@ -43,7 +43,7 @@ _C.INPUT.PADDING = 10
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training, as present in paths_catalog.py
_C.DATASETS.NAMES = ()
_C.DATASETS.NAMES = ("cuhk03",)

# -----------------------------------------------------------------------------
# DataLoader

@@ -11,6 +11,7 @@ import re
from fastai.vision import *
from .transforms import RandomErasing
from .samplers import RandomIdentitySampler
from .datasets import CUHK03


def get_data_bunch(cfg):

@@ -39,7 +40,6 @@ def get_data_bunch(cfg):
v_paths.append([img_path,pid,camid])
return v_paths


market_train_path = 'datasets/Market-1501-v15.09.15/bounding_box_train'
duke_train_path = 'datasets/DukeMTMC-reID/bounding_box_train'
cuhk03_train_path = 'datasets/cuhk03/'
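
For orientation, _process_dir in the hunk above returns [img_path, pid, camid] records collected from a bounding_box_train folder. A minimal sketch of that conventional Market-1501/DukeMTMC directory walk, assuming the usual '<pid>_c<cam>...' file naming; the glob pattern, the recursive flag, and the pid == -1 junk filter are assumptions, only the record format comes from the hunk:

import glob
import os.path as osp
import re

def _process_dir(dir_path, recursive=False):
    # Market-1501/Duke file names look like 0002_c1s1_000451_03.jpg:
    # person id first, camera id after '_c'.
    pat = re.compile(r'([-\d]+)_c(\d)')
    pattern = '**/*.jpg' if recursive else '*.jpg'
    v_paths = []
    for img_path in glob.glob(osp.join(dir_path, pattern), recursive=recursive):
        pid, camid = map(int, pat.search(osp.basename(img_path)).groups())
        if pid == -1:
            continue  # distractor/junk images are conventionally marked with pid -1
        v_paths.append([img_path, pid, camid])
    return v_paths
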

@@ -58,12 +58,11 @@ def get_data_bunch(cfg):
train_img_names.extend(_process_dir(duke_train_path))
elif d == 'beijing':
train_img_names.extend(_process_dir(bjStation_train_path, True))
elif d == 'cuhk03':
train_img_names.extend(CUHK03().train)
else:
raise NameError("{} is not available".format(d))

# train_img_names = _process_dir(market_train_path) + _process_dir(duke_train_path) + _process_dir(bjStation_train_path)
# train_img_names = _process_dir(market_train_path)
# train_img_names = CUHK03().train
train_names = [i[0] for i in train_img_names]

query_names = _process_dir(bj_query_path)
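
The elif chain above is driven by the DATASETS.NAMES tuple from the config, and any unknown name fails fast with NameError. An equivalent table-driven sketch of that dispatch, not the repo's code; the loader table is assembled from the calls shown in (and implied by) the diff:

LOADERS = {
    'market1501': lambda: _process_dir(market_train_path),
    'duke': lambda: _process_dir(duke_train_path),
    'cuhk03': lambda: CUHK03().train,
    'beijing': lambda: _process_dir(bjStation_train_path, True),
}

def collect_train_images(dataset_names):
    # Hypothetical helper: merge the per-dataset image lists in config order.
    train_img_names = []
    for d in dataset_names:
        if d not in LOADERS:
            raise NameError("{} is not available".format(d))
        train_img_names.extend(LOADERS[d]())
    return train_img_names

# e.g. collect_train_images(("market1501", "duke", "cuhk03", "beijing"))
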

@@ -77,17 +76,17 @@ def get_data_bunch(cfg):

def get_labels(file_path):
""" Suitable for multi-dataset training """
# if 'cuhk03' in file_path:
# prefix = 'cuhk'
# pid = file_path.split('/')[-1].split('_')[1]
# else:
if 'cuhk03' in file_path:
prefix = 'cuhk'
pid = '_'.join(file_path.split('/')[-1].split('_')[0:2])
else:
prefix = file_path.split('/')[1]
pat = re.compile(r'([-\d]+)_c(\d)')
pid, _ = pat.search(file_path).groups()
return prefix + '_' + pid

data_bunch = ImageDataBunch.from_name_func('datasets', train_names, label_func=get_labels, valid_pct=0,
size=(256, 128), ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH,
size=cfg.INPUT.SIZE_TRAIN, ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH,
val_bs=cfg.TEST.IMS_PER_BATCH)

if 'triplet' in cfg.DATALOADER.SAMPLER:
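
Since the diff rendering above drops indentation and +/- markers, here is a runnable reading of the new get_labels for reference; the nesting is an interpretation of the flattened hunk and the example paths are invented:

import re

def get_labels(file_path):
    """ Suitable for multi-dataset training """
    if 'cuhk03' in file_path:
        prefix = 'cuhk'
        pid = '_'.join(file_path.split('/')[-1].split('_')[0:2])
    else:
        prefix = file_path.split('/')[1]
        pat = re.compile(r'([-\d]+)_c(\d)')
        pid, _ = pat.search(file_path).groups()
    return prefix + '_' + pid

# invented paths, for illustration only
get_labels('datasets/Market-1501-v15.09.15/bounding_box_train/0002_c1s1_000451_03.jpg')  # 'Market-1501-v15.09.15_0002'
get_labels('datasets/cuhk03/1_003_1_02.png')  # 'cuhk_1_003'
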

@@ -34,7 +34,11 @@ class RandomIdentitySampler(Sampler):
self.index_dic = defaultdict(list)
for index, fname in enumerate(self.data_source):
prefix = fname.split('/')[1]
try:
pid, _ = pat.search(fname).groups()
except:
prefix = fname.split('/')[4]
pid = '_'.join(fname.split('/')[-1].split('_')[:2])
pid = prefix + '_' + pid
self.index_dic[pid].append(index)
self.pids = list(self.index_dic.keys())
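
The try/except above lets the sampler build a dataset-prefixed person id from two different path layouts. For background, RandomIdentitySampler implements the usual P×K identity sampling used with triplet loss; a simplified, self-contained sketch of that idea, not the repo's exact implementation:

import random
from collections import defaultdict
from torch.utils.data.sampler import Sampler

class PKSampler(Sampler):
    """Yield num_instances indices per identity, identities in random order."""
    def __init__(self, labels, num_instances=4):
        self.index_dic = defaultdict(list)
        for index, pid in enumerate(labels):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_instances = num_instances

    def __iter__(self):
        for pid in random.sample(self.pids, len(self.pids)):
            idxs = self.index_dic[pid]
            if len(idxs) >= self.num_instances:
                chosen = random.sample(idxs, self.num_instances)
            else:
                chosen = random.choices(idxs, k=self.num_instances)  # reuse images of rare identities
            yield from chosen

    def __len__(self):
        return len(self.pids) * self.num_instances
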

@@ -1,6 +1,7 @@
gpu=3
gpu=2

CUDA_VISIBLE_DEVICES=$gpu python tools/test.py -cfg='configs/softmax_triplet.yml' \
INPUT.SIZE_TRAIN '(384, 128)' \
DATASETS.NAMES '("market1501","duke","beijing")' \
OUTPUT_DIR 'logs/test' \
TEST.WEIGHT 'logs/beijing/market+duke+bj/models/model_149.pth'

@@ -0,0 +1,5 @@
gpu=3

CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("market1501","beijing",)' \
OUTPUT_DIR 'logs/beijing/market_bj'

@@ -0,0 +1,5 @@
gpu=2

CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("market1501","duke","cuhk03","beijing")' \
OUTPUT_DIR 'logs/beijing/market_duke_cuhk03_beijing_256x128_bs512'
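
The trailing KEY VALUE pairs in these scripts are free-form config overrides. A minimal sketch of how a tools/train.py entry point typically forwards them to the yacs config; the -cfg flag and the opts argument are assumptions inferred from the scripts, not verified against the repo:

import argparse
from config import cfg  # hypothetical import of the defaults shown in the first hunk

parser = argparse.ArgumentParser(description="reid training")
parser.add_argument("-cfg", "--config_file", default="", type=str)
parser.add_argument("opts", nargs=argparse.REMAINDER,
                    help="KEY VALUE pairs, e.g. DATASETS.NAMES '(\"market1501\",\"beijing\",)'")
args = parser.parse_args()

if args.config_file:
    cfg.merge_from_file(args.config_file)  # e.g. configs/softmax_triplet.yml
cfg.merge_from_list(args.opts)             # command-line pairs take precedence
cfg.freeze()
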

@@ -50,10 +50,11 @@ def main():
cudnn.benchmark = True

data_bunch, test_labels, num_query = get_data_bunch(cfg)
model = build_model(cfg, data_bunch.c)
state_dict = torch.load(cfg.TEST.WEIGHT)
model.load_state_dict(state_dict['model'])
model.cuda()
# model = build_model(cfg, data_bunch.c)
# state_dict = torch.load(cfg.TEST.WEIGHT)
# model.load_state_dict(state_dict['model'])
# model.cuda()
model = torch.jit.load("/export/home/lxy/reid_baseline/pcb_model_v0.2.pt")

inference(cfg, model, data_bunch, test_labels, num_query)
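
The hunk above bypasses build_model plus state_dict loading and instead loads a serialized TorchScript module with torch.jit.load. A hedged sketch of that export/load round trip; the backbone, input size, and file name below are placeholders, not the repo's:

import torch
import torchvision

# Export once: trace an eager model with an example input and save the archive.
model = torchvision.models.resnet50(pretrained=False).eval()
example = torch.rand(1, 3, 256, 128)
torch.jit.trace(model, example).save("reid_model_traced.pt")

# Later, e.g. in tools/test.py: reload without needing the Python model class.
model = torch.jit.load("reid_model_traced.pt")
model.eval()
with torch.no_grad():
    feat = model(example)
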

@@ -27,8 +27,8 @@ def train(cfg):

# prepare model
model = build_model(cfg, data_bunch.c)
state_dict = torch.load("logs/beijing/market_duke_softmax_triplet_256_128_bs512/models/model_149.pth")
model.load_params_wo_fc(state_dict['model'])
# state_dict = torch.load("logs/beijing/market_duke_softmax_triplet_256_128_bs512/models/model_149.pth")
# model.load_params_wo_fc(state_dict['model'])

opt_func = partial(torch.optim.Adam)
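
For reference, load_params_wo_fc above loads a checkpoint while skipping the classification head, so a model trained on one identity set can warm-start another. A generic equivalent using load_state_dict(strict=False); the 'fc'/'classifier' key prefixes are assumptions about layer names, only the ['model'] checkpoint key comes from the diff:

import torch

def load_params_without_classifier(model, checkpoint_path):
    # Drop classifier weights so a checkpoint trained with a different number
    # of identities can still initialize the backbone.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    filtered = {k: v for k, v in state_dict.items()
                if not k.startswith(("fc", "classifier"))}
    return model.load_state_dict(filtered, strict=False)
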