mirror of https://github.com/JDAI-CV/fast-reid.git
add kwargs for convenient dataset parameters passing (#290)
Summary: make it more easy for passing dataset kwargs through `build_reid_train_loader`pull/380/head
parent
0e1b91f74a
commit
8083547613
|
@ -19,12 +19,12 @@ from .transforms import build_transforms
|
|||
_root = os.getenv("FASTREID_DATASETS", "datasets")
|
||||
|
||||
|
||||
def build_reid_train_loader(cfg, mapper=None):
|
||||
def build_reid_train_loader(cfg, mapper=None, **kwargs):
|
||||
cfg = cfg.clone()
|
||||
|
||||
train_items = list()
|
||||
for d in cfg.DATASETS.NAMES:
|
||||
dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
|
||||
dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL, **kwargs)
|
||||
if comm.is_main_process():
|
||||
dataset.show_train()
|
||||
train_items.extend(dataset.train)
|
||||
|
@ -59,10 +59,10 @@ def build_reid_train_loader(cfg, mapper=None):
|
|||
return train_loader
|
||||
|
||||
|
||||
def build_reid_test_loader(cfg, dataset_name):
|
||||
def build_reid_test_loader(cfg, dataset_name, **kwargs):
|
||||
cfg = cfg.clone()
|
||||
|
||||
dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
|
||||
dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
|
||||
if comm.is_main_process():
|
||||
dataset.show_test()
|
||||
test_items = dataset.query + dataset.gallery
|
||||
|
|
Loading…
Reference in New Issue