mirror of https://github.com/JDAI-CV/fast-reid.git
parent
defc190f6b
commit
08d105560c
|
@ -4,3 +4,5 @@ __pycache__
|
|||
.vscode
|
||||
datasets
|
||||
csrc/eval_cylib/build/
|
||||
logs/
|
||||
*.ipynb
|
|
@ -35,6 +35,6 @@ TEST:
|
|||
IMS_PER_BATCH: 512
|
||||
WEIGHT: "path"
|
||||
|
||||
OUTPUT_DIR: "logs/co-train/test"
|
||||
OUTPUT_DIR: "logs/beijing/market+duke+bj/"
|
||||
|
||||
|
||||
|
|
|
@ -23,7 +23,13 @@ def get_data_bunch(cfg):
|
|||
)
|
||||
|
||||
def _process_dir(dir_path):
|
||||
img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
|
||||
img_paths = []
|
||||
if 'beijingStation' in dir_path:
|
||||
id_dirs = os.listdir(dir_path)
|
||||
for d in id_dirs:
|
||||
img_paths.extend(glob.glob(os.path.join(dir_path, d, '*.jpg')))
|
||||
else:
|
||||
img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
|
||||
pattern = re.compile(r'([-\d]+)_c(\d)')
|
||||
|
||||
pid_container = set()
|
||||
|
@ -37,12 +43,14 @@ def get_data_bunch(cfg):
|
|||
|
||||
market_train_path = 'datasets/Market-1501-v15.09.15/bounding_box_train'
|
||||
duke_train_path = 'datasets/DukeMTMC-reID/bounding_box_train'
|
||||
cuhk03_train_path = ''
|
||||
cuhk03_train_path = 'datasets/cuhk03/'
|
||||
bjStation_train_path = 'datasets/beijingStation/20190720/train'
|
||||
|
||||
query_path = 'datasets/Market-1501-v15.09.15/query'
|
||||
gallery_path = 'datasets/Market-1501-v15.09.15/bounding_box_test'
|
||||
|
||||
# train_img_names = _process_dir(market_train_path) + _process_dir(duke_train_path)
|
||||
train_img_names = _process_dir(market_train_path)
|
||||
train_img_names = _process_dir(market_train_path) + _process_dir(duke_train_path) + _process_dir(bjStation_train_path)
|
||||
# train_img_names = CUHK03().train
|
||||
train_names = [i[0] for i in train_img_names]
|
||||
|
||||
query_names = _process_dir(query_path)
|
||||
|
@ -55,6 +63,10 @@ def get_data_bunch(cfg):
|
|||
|
||||
def get_labels(file_path):
|
||||
""" Suitable for muilti-dataset training """
|
||||
# if 'cuhk03' in file_path:
|
||||
# prefix = 'cuhk'
|
||||
# pid = file_path.split('/')[-1].split('_')[1]
|
||||
# else:
|
||||
prefix = file_path.split('/')[1]
|
||||
pat = re.compile(r'([-\d]+)_c(\d)')
|
||||
pid, _ = pat.search(file_path).groups()
|
||||
|
@ -63,6 +75,7 @@ def get_data_bunch(cfg):
|
|||
data_bunch = ImageDataBunch.from_name_func('datasets', train_names, label_func=get_labels, valid_pct=0,
|
||||
size=(256, 128), ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH,
|
||||
val_bs=cfg.TEST.IMS_PER_BATCH)
|
||||
|
||||
if 'triplet' in cfg.DATALOADER.SAMPLER:
|
||||
data_sampler = RandomIdentitySampler(train_names, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
|
||||
data_bunch.train_dl = data_bunch.train_dl.new(shuffle=False, sampler=data_sampler)
|
||||
|
|
|
@ -31,7 +31,7 @@ class CUHK03(BaseImageDataset):
|
|||
"""
|
||||
dataset_dir = 'cuhk03'
|
||||
|
||||
def __init__(self, root='/export/home/lxy/DATA/reid', split_id=0, cuhk03_labeled=False,
|
||||
def __init__(self, root='datasets', split_id=0, cuhk03_labeled=False,
|
||||
cuhk03_classic_split=False, verbose=True,
|
||||
**kwargs):
|
||||
super(CUHK03, self).__init__()
|
||||
|
|
|
@ -81,6 +81,7 @@ class TestModel(LearnerCallback):
|
|||
self._logger.info("mAP: {:.1%}".format(mAP))
|
||||
for r in [1, 5, 10]:
|
||||
self._logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
|
||||
self.learn.save("model_{}".format(epoch))
|
||||
|
||||
def do_train(
|
||||
cfg,
|
||||
|
|
|
@ -35,10 +35,11 @@ def weights_init_classifier(m):
|
|||
class Baseline(nn.Module):
|
||||
in_planes = 2048
|
||||
|
||||
def __init__(self, num_classes, last_stride, model_path):
|
||||
def __init__(self, num_classes, last_stride, model_path=None):
|
||||
super(Baseline, self).__init__()
|
||||
self.base = ResNet(last_stride)
|
||||
self.base.load_param(model_path)
|
||||
if model_path is not None:
|
||||
self.base.load_param(model_path)
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.num_classes = num_classes
|
||||
|
||||
|
|
|
@ -11,5 +11,4 @@ from config import cfg
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = get_data_bunch(cfg)
|
||||
from IPython import embed; embed()
|
||||
data = get_data_bunch(cfg)
|
|
@ -77,16 +77,10 @@ def main():
|
|||
logger = setup_logger("reid_baseline", cfg.OUTPUT_DIR, 0)
|
||||
logger.info("Using {} GPUs.".format(num_gpus))
|
||||
logger.info(args)
|
||||
# with log_path.open('w') as f: f.write('{}\n'.format(args))
|
||||
# print(args)
|
||||
|
||||
if args.config_file != "":
|
||||
logger.info("Loaded configuration file {}".format(args.config_file))
|
||||
# with open(args.config_file, 'r') as cf:
|
||||
# config_str = "\n" + cf.read()
|
||||
# logger.info(config_str)
|
||||
logger.info("Running with config:\n{}".format(cfg))
|
||||
# with log_path.open('a') as f: f.write('{}\n'.format(cfg))
|
||||
|
||||
cudnn.benchmark = True
|
||||
train(cfg)
|
||||
|
|
Loading…
Reference in New Issue