From 2724515fd989832ce2a1e4af05ba755dfd292d51 Mon Sep 17 00:00:00 2001 From: liaoxingyu Date: Fri, 6 Nov 2020 16:07:37 +0800 Subject: [PATCH] save class number to config (#281) Summary: Save the class number calculated based on datasets to the config file. If you hard-code the class number, make it unchanged. --- fastreid/engine/defaults.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py index 8a2dfa6..69482be 100644 --- a/fastreid/engine/defaults.py +++ b/fastreid/engine/defaults.py @@ -472,8 +472,19 @@ class DefaultTrainer(SimpleTrainer): frozen = cfg.is_frozen() cfg.defrost() + # If you don't hard-code the number of classes, it will compute the number automatically + if cfg.MODEL.HEADS.NUM_CLASSES == 0: + output_dir = cfg.OUTPUT_DIR + cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes + # Update the saved config file to make the number of classes valid + if comm.is_main_process() and output_dir: + # Note: some of our scripts may expect the existence of + # config.yaml in output directory + path = os.path.join(output_dir, "config.yaml") + with PathManager.open(path, "w") as f: + f.write(cfg.dump()) + iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH - cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes cfg.SOLVER.MAX_ITER *= iters_per_epoch cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch cfg.SOLVER.FREEZE_ITERS *= iters_per_epoch