add kwargs for convenient dataset parameters passing (#290)

Summary: make it more easy for passing dataset kwargs through `build_reid_train_loader`
pull/380/head
liaoxingyu 2020-12-28 14:39:08 +08:00
parent 0e1b91f74a
commit 8083547613
1 changed files with 4 additions and 4 deletions

View File

@ -19,12 +19,12 @@ from .transforms import build_transforms
_root = os.getenv("FASTREID_DATASETS", "datasets")
def build_reid_train_loader(cfg, mapper=None):
def build_reid_train_loader(cfg, mapper=None, **kwargs):
cfg = cfg.clone()
train_items = list()
for d in cfg.DATASETS.NAMES:
dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL, **kwargs)
if comm.is_main_process():
dataset.show_train()
train_items.extend(dataset.train)
@ -59,10 +59,10 @@ def build_reid_train_loader(cfg, mapper=None):
return train_loader
def build_reid_test_loader(cfg, dataset_name):
def build_reid_test_loader(cfg, dataset_name, **kwargs):
cfg = cfg.clone()
dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
if comm.is_main_process():
dataset.show_test()
test_items = dataset.query + dataset.gallery