save class number to config (#281)

Summary: Save the class number calculated based on datasets to the config file. If you hard-code the class number, make it unchanged.
pull/365/head
liaoxingyu 2020-11-06 16:07:37 +08:00
parent 7e9a4775da
commit 2724515fd9
1 changed files with 12 additions and 1 deletions

View File

@ -472,8 +472,19 @@ class DefaultTrainer(SimpleTrainer):
frozen = cfg.is_frozen()
cfg.defrost()
# If you don't hard-code the number of classes, it will compute the number automatically
if cfg.MODEL.HEADS.NUM_CLASSES == 0:
output_dir = cfg.OUTPUT_DIR
cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
# Update the saved config file to make the number of classes valid
if comm.is_main_process() and output_dir:
# Note: some of our scripts may expect the existence of
# config.yaml in output directory
path = os.path.join(output_dir, "config.yaml")
with PathManager.open(path, "w") as f:
f.write(cfg.dump())
iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
cfg.SOLVER.MAX_ITER *= iters_per_epoch
cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
cfg.SOLVER.FREEZE_ITERS *= iters_per_epoch