Finish multi-dataset training
pull/25/head
liaoxingyu 2019-07-30 18:56:43 +08:00
parent defc190f6b
commit 08d105560c
8 changed files with 26 additions and 16 deletions

.gitignore vendored
View File

@@ -4,3 +4,5 @@ __pycache__
 .vscode
 datasets
 csrc/eval_cylib/build/
+logs/
+*.ipynb

View File

@@ -35,6 +35,6 @@ TEST:
   IMS_PER_BATCH: 512
   WEIGHT: "path"
-OUTPUT_DIR: "logs/co-train/test"
+OUTPUT_DIR: "logs/beijing/market+duke+bj/"

View File

@@ -23,7 +23,13 @@ def get_data_bunch(cfg):
     )
     def _process_dir(dir_path):
-        img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
+        img_paths = []
+        if 'beijingStation' in dir_path:
+            id_dirs = os.listdir(dir_path)
+            for d in id_dirs:
+                img_paths.extend(glob.glob(os.path.join(dir_path, d, '*.jpg')))
+        else:
+            img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
         pattern = re.compile(r'([-\d]+)_c(\d)')
         pid_container = set()
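For context, a minimal sketch of the two layouts the reworked _process_dir handles: Market/Duke keep all JPEGs in one flat folder, while the beijingStation export groups images into one sub-folder per identity; in both cases the ([-\d]+)_c(\d) pattern then recovers the person and camera IDs. The example paths and the helper name below are made up for illustration.

import glob
import os
import re

def list_images(dir_path):
    # Flat layout:         dir_path/0002_c1s1_000451_03.jpg
    # Per-identity layout:  dir_path/0002/0002_c1s1_000451_03.jpg
    flat = glob.glob(os.path.join(dir_path, '*.jpg'))
    nested = glob.glob(os.path.join(dir_path, '*', '*.jpg'))
    return flat + nested

pattern = re.compile(r'([-\d]+)_c(\d)')
pid, camid = pattern.search('0002_c1s1_000451_03.jpg').groups()
print(pid, camid)  # -> '0002' '1' (Market-1501 marks junk images with pid -1)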
@@ -37,12 +43,14 @@ def get_data_bunch(cfg):
     market_train_path = 'datasets/Market-1501-v15.09.15/bounding_box_train'
     duke_train_path = 'datasets/DukeMTMC-reID/bounding_box_train'
-    cuhk03_train_path = ''
+    cuhk03_train_path = 'datasets/cuhk03/'
+    bjStation_train_path = 'datasets/beijingStation/20190720/train'
     query_path = 'datasets/Market-1501-v15.09.15/query'
     gallery_path = 'datasets/Market-1501-v15.09.15/bounding_box_test'
     # train_img_names = _process_dir(market_train_path) + _process_dir(duke_train_path)
-    train_img_names = _process_dir(market_train_path)
+    train_img_names = _process_dir(market_train_path) + _process_dir(duke_train_path) + _process_dir(bjStation_train_path)
+    # train_img_names = CUHK03().train
     train_names = [i[0] for i in train_img_names]
     query_names = _process_dir(query_path)
@@ -55,6 +63,10 @@ def get_data_bunch(cfg):
     def get_labels(file_path):
         """ Suitable for muilti-dataset training """
+        # if 'cuhk03' in file_path:
+        #     prefix = 'cuhk'
+        #     pid = file_path.split('/')[-1].split('_')[1]
+        # else:
         prefix = file_path.split('/')[1]
         pat = re.compile(r'([-\d]+)_c(\d)')
         pid, _ = pat.search(file_path).groups()
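A short sketch of why get_labels prefixes the person id with the dataset folder: when several datasets are concatenated, Market-1501 id 0002 and DukeMTMC-reID id 0002 must stay separate classes. The combination step (prefix + '_' + pid) is an assumption for illustration, since the return statement is outside the hunk, and the example paths are made up.

import re

def label_for(file_path):
    prefix = file_path.split('/')[1]                            # dataset folder name
    pid, _ = re.search(r'([-\d]+)_c(\d)', file_path).groups()
    return prefix + '_' + pid                                   # assumed combination

print(label_for('datasets/Market-1501-v15.09.15/bounding_box_train/0002_c1s1_000451_03.jpg'))
print(label_for('datasets/DukeMTMC-reID/bounding_box_train/0002_c1_f0046182.jpg'))
# -> Market-1501-v15.09.15_0002 and DukeMTMC-reID_0002 stay distinct labels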
@@ -63,6 +75,7 @@ def get_data_bunch(cfg):
     data_bunch = ImageDataBunch.from_name_func('datasets', train_names, label_func=get_labels, valid_pct=0,
                                                size=(256, 128), ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH,
                                                val_bs=cfg.TEST.IMS_PER_BATCH)
+    if 'triplet' in cfg.DATALOADER.SAMPLER:
         data_sampler = RandomIdentitySampler(train_names, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
         data_bunch.train_dl = data_bunch.train_dl.new(shuffle=False, sampler=data_sampler)
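For readers unfamiliar with the identity sampler used above: a triplet loss needs positive pairs inside every batch, so the sampler emits indices grouped by person id instead of relying on uniform shuffling, which is also why shuffle=False is passed to train_dl.new. Below is a simplified sketch of that idea, not the repo's RandomIdentitySampler; the class name and constructor are made up.

import random
from collections import defaultdict
from torch.utils.data import Sampler

class PKSampler(Sampler):
    """Yields indices so that every num_instances consecutive samples share
    one identity; with shuffle disabled, a batch of size P*K then holds
    P identities with K images each."""

    def __init__(self, labels, num_instances=4):
        self.index_dic = defaultdict(list)
        for idx, pid in enumerate(labels):
            self.index_dic[pid].append(idx)
        self.pids = list(self.index_dic)
        self.num_instances = num_instances

    def __len__(self):
        return len(self.pids) * self.num_instances

    def __iter__(self):
        out = []
        for pid in random.sample(self.pids, len(self.pids)):
            idxs = self.index_dic[pid]
            k = self.num_instances
            # sample with replacement when an identity has fewer than k images
            out.extend(random.choices(idxs, k=k) if len(idxs) < k
                       else random.sample(idxs, k))
        return iter(out)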

View File

@@ -31,7 +31,7 @@ class CUHK03(BaseImageDataset):
     """
     dataset_dir = 'cuhk03'
-    def __init__(self, root='/export/home/lxy/DATA/reid', split_id=0, cuhk03_labeled=False,
+    def __init__(self, root='datasets', split_id=0, cuhk03_labeled=False,
                  cuhk03_classic_split=False, verbose=True,
                  **kwargs):
         super(CUHK03, self).__init__()

View File

@@ -81,6 +81,7 @@ class TestModel(LearnerCallback):
         self._logger.info("mAP: {:.1%}".format(mAP))
         for r in [1, 5, 10]:
             self._logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
+        self.learn.save("model_{}".format(epoch))
 def do_train(
         cfg,

View File

@@ -35,10 +35,11 @@ def weights_init_classifier(m):
 class Baseline(nn.Module):
     in_planes = 2048
-    def __init__(self, num_classes, last_stride, model_path):
+    def __init__(self, num_classes, last_stride, model_path=None):
         super(Baseline, self).__init__()
         self.base = ResNet(last_stride)
-        self.base.load_param(model_path)
+        if model_path is not None:
+            self.base.load_param(model_path)
         self.gap = nn.AdaptiveAvgPool2d(1)
         self.num_classes = num_classes
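Making model_path optional means the backbone can now be built without a pretrained-weights file, e.g. when a full checkpoint is restored right afterwards. A usage sketch, with hypothetical argument values and an import path that depends on the repo layout:

from modeling.baseline import Baseline   # assumed import path

model = Baseline(num_classes=751, last_stride=1)   # no ImageNet initialization
model_imagenet = Baseline(num_classes=751, last_stride=1,
                          model_path='weights/resnet50-19c8e357.pth')  # hypothetical weight file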

View File

@@ -11,5 +11,4 @@ from config import cfg
 if __name__ == '__main__':
     data = get_data_bunch(cfg)
     from IPython import embed; embed()
-    data = get_data_bunch(cfg)

View File

@@ -77,16 +77,10 @@ def main():
     logger = setup_logger("reid_baseline", cfg.OUTPUT_DIR, 0)
     logger.info("Using {} GPUs.".format(num_gpus))
     logger.info(args)
-    # with log_path.open('w') as f: f.write('{}\n'.format(args))
-    # print(args)
     if args.config_file != "":
         logger.info("Loaded configuration file {}".format(args.config_file))
-        # with open(args.config_file, 'r') as cf:
-        #     config_str = "\n" + cf.read()
-        #     logger.info(config_str)
     logger.info("Running with config:\n{}".format(cfg))
-    # with log_path.open('a') as f: f.write('{}\n'.format(cfg))
     cudnn.benchmark = True
     train(cfg)