mirror of https://github.com/JDAI-CV/fast-reid.git
43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import sys
|
|
sys.path.append('.')
|
|
from data import get_dataloader
|
|
from config import cfg
|
|
import argparse
|
|
from data.datasets import init_dataset
|
|
# Optional sampler override, kept disabled for manual experiments:
# cfg.DATALOADER.SAMPLER = 'triplet'

# Dataset names set on the shared config before it is passed to
# get_dataloader in the __main__ section.
cfg.DATASETS.NAMES = (
    "market1501",
    "dukemtmc",
    "cuhk03",
    "msmt17",
)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the config from the command line and call
    # get_dataloader on it. The return value is discarded, so running this
    # only exercises dataloader construction.
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        '-cfg', "--config_file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # Honor --config_file: it used to be parsed and then ignored. The file is
    # merged first so that explicit command-line `opts` take precedence.
    # NOTE(review): assumes cfg is a yacs-style CfgNode exposing
    # merge_from_file alongside merge_from_list -- confirm against config.
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    get_dataloader(cfg)