2019-01-10 18:39:31 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
2019-04-21 13:38:55 +08:00
|
|
|
@author: l1aoxingyu
|
2019-01-10 18:39:31 +08:00
|
|
|
@contact: sherlockliao01@gmail.com
|
|
|
|
"""
|
|
|
|
|
2019-04-21 13:38:55 +08:00
|
|
|
import glob
|
2019-07-23 09:46:37 +08:00
|
|
|
import os
|
|
|
|
import re
|
2019-01-10 18:39:31 +08:00
|
|
|
|
2019-04-21 13:38:55 +08:00
|
|
|
from fastai.vision import *
|
2019-08-07 16:54:50 +08:00
|
|
|
|
2019-08-01 15:49:31 +08:00
|
|
|
from .datasets import CUHK03
|
2019-08-07 16:54:50 +08:00
|
|
|
from .samplers import RandomIdentitySampler
|
|
|
|
from .transforms import build_transforms
|
2019-04-21 13:38:55 +08:00
|
|
|
|
|
|
|
|
|
|
|
def get_data_bunch(cfg):
    """Build the fastai ``ImageDataBunch`` for re-id training and testing.

    Training images are collected from every dataset listed in
    ``cfg.DATASETS.NAMES``; the query and gallery images of
    ``cfg.DATASETS.TEST_NAMES`` are attached as the test set.

    Args:
        cfg: Config object. This function reads ``DATASETS.NAMES``,
            ``DATASETS.TEST_NAMES``, ``INPUT.SIZE_TRAIN``,
            ``SOLVER.IMS_PER_BATCH``, ``TEST.IMS_PER_BATCH``,
            ``DATALOADER.SAMPLER`` and ``DATALOADER.NUM_INSTANCE``.

    Returns:
        A 3-tuple ``(data_bunch, test_labels, num_query)`` where
        ``test_labels`` holds ``[pid, camid]`` for each test image (query
        entries first, then gallery) and ``num_query`` is the number of
        query entries.

    Raises:
        NameError: If a training dataset name or the test dataset name is
            not supported.
    """

    def _process_dir(dir_path):
        """Return ``[img_path, pid, camid]`` for each usable ``*.jpg`` in ``dir_path``."""
        pattern = re.compile(r'([-\d]+)_c(\d*)')
        v_paths = []
        for img_path in glob.glob(os.path.join(dir_path, '*.jpg')):
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            v_paths.append([img_path, pid, camid])
        return v_paths

    # Compiled once here instead of on every get_labels() call.
    label_pattern = re.compile(r'([-\d]+)_c(\d)')

    def get_labels(file_path):
        """Return a dataset-prefixed identity label; suitable for multi-dataset training."""
        if 'cuhk03' in file_path:
            prefix = 'cuhk'
            # The first two underscore-separated fields of the file name form the id.
            pid = '_'.join(file_path.split('/')[-1].split('_')[0:2])
        else:
            # Second path component, i.e. the dataset directory under 'datasets/'.
            prefix = file_path.split('/')[1]
            pid, _ = label_pattern.search(file_path).groups()
        return prefix + '_' + pid

    market_train_path = 'datasets/Market-1501-v15.09.15/bounding_box_train'
    duke_train_path = 'datasets/DukeMTMC-reID/bounding_box_train'

    market_query_path = 'datasets/Market-1501-v15.09.15/query'
    market_gallery_path = 'datasets/Market-1501-v15.09.15/bounding_box_test'

    train_img_names = []
    for d in cfg.DATASETS.NAMES:
        if d == 'market1501':
            train_img_names.extend(_process_dir(market_train_path))
        elif d == 'duke':
            train_img_names.extend(_process_dir(duke_train_path))
        elif d == 'cuhk03':
            train_img_names.extend(CUHK03().train)
        else:
            raise NameError(f'{d} is not available')

    train_names = [i[0] for i in train_img_names]

    if cfg.DATASETS.TEST_NAMES == "market1501":
        query_names = _process_dir(market_query_path)
        gallery_names = _process_dir(market_gallery_path)
    else:
        # Fail here with a clear message instead of printing and then
        # crashing below on the undefined query/gallery lists.
        raise NameError(f"not support {cfg.DATASETS.TEST_NAMES} test set")

    test_fnames = []
    test_labels = []
    for i in query_names + gallery_names:
        test_fnames.append(i[0])
        test_labels.append(i[1:])

    # Built only after the dataset configuration has been validated.
    ds_tfms = build_transforms(cfg)

    data_bunch = ImageDataBunch.from_name_func('datasets', train_names, label_func=get_labels, valid_pct=0,
                                               size=cfg.INPUT.SIZE_TRAIN, ds_tfms=ds_tfms,
                                               bs=cfg.SOLVER.IMS_PER_BATCH,
                                               val_bs=cfg.TEST.IMS_PER_BATCH)

    if 'triplet' in cfg.DATALOADER.SAMPLER:
        data_sampler = RandomIdentitySampler(train_names, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
        # Shuffling is disabled because the sampler decides the batch order.
        data_bunch.train_dl = data_bunch.train_dl.new(shuffle=False, sampler=data_sampler)

    data_bunch.add_test(test_fnames)
    data_bunch.normalize(imagenet_stats)

    return data_bunch, test_labels, len(query_names)
|