diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py index 8a2dfa6..69482be 100644 --- a/fastreid/engine/defaults.py +++ b/fastreid/engine/defaults.py @@ -472,8 +472,19 @@ class DefaultTrainer(SimpleTrainer): frozen = cfg.is_frozen() cfg.defrost() + # If you don't hard-code the number of classes, it will compute the number automatically + if cfg.MODEL.HEADS.NUM_CLASSES == 0: + output_dir = cfg.OUTPUT_DIR + cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes + # Update the saved config file to make the number of classes valid + if comm.is_main_process() and output_dir: + # Note: some of our scripts may expect the existence of + # config.yaml in output directory + path = os.path.join(output_dir, "config.yaml") + with PathManager.open(path, "w") as f: + f.write(cfg.dump()) + iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH - cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes cfg.SOLVER.MAX_ITER *= iters_per_epoch cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch cfg.SOLVER.FREEZE_ITERS *= iters_per_epoch