mirror of https://github.com/JDAI-CV/fast-reid.git
save class number to config (#281)
Summary: Save the class number calculated based on datasets to the config file. If you hard-code the class number, make it unchanged.pull/365/head
parent
7e9a4775da
commit
2724515fd9
|
@ -472,8 +472,19 @@ class DefaultTrainer(SimpleTrainer):
|
|||
frozen = cfg.is_frozen()
|
||||
cfg.defrost()
|
||||
|
||||
# If you don't hard-code the number of classes, it will compute the number automatically
|
||||
if cfg.MODEL.HEADS.NUM_CLASSES == 0:
|
||||
output_dir = cfg.OUTPUT_DIR
|
||||
cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
|
||||
# Update the saved config file to make the number of classes valid
|
||||
if comm.is_main_process() and output_dir:
|
||||
# Note: some of our scripts may expect the existence of
|
||||
# config.yaml in output directory
|
||||
path = os.path.join(output_dir, "config.yaml")
|
||||
with PathManager.open(path, "w") as f:
|
||||
f.write(cfg.dump())
|
||||
|
||||
iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
|
||||
cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
|
||||
cfg.SOLVER.MAX_ITER *= iters_per_epoch
|
||||
cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
|
||||
cfg.SOLVER.FREEZE_ITERS *= iters_per_epoch
|
||||
|
|
Loading…
Reference in New Issue