diff --git a/fastreid/data/build.py b/fastreid/data/build.py index fad5397..92202b2 100644 --- a/fastreid/data/build.py +++ b/fastreid/data/build.py @@ -19,6 +19,9 @@ _root = os.getenv("FASTREID_DATASETS", "datasets") def build_reid_train_loader(cfg): + cfg = cfg.clone() + cfg.defrost() + train_items = list() for d in cfg.DATASETS.NAMES: dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL) @@ -26,7 +29,6 @@ def build_reid_train_loader(cfg): dataset.show_train() train_items.extend(dataset.train) - cfg.defrost() iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH cfg.SOLVER.MAX_ITER *= iters_per_epoch train_transforms = build_transforms(cfg, is_train=True) @@ -57,6 +59,9 @@ def build_reid_train_loader(cfg): def build_reid_test_loader(cfg, dataset_name): + cfg = cfg.clone() + cfg.defrost() + dataset = DATASET_REGISTRY.get(dataset_name)(root=_root) if comm.is_main_process(): dataset.show_test()