add check_cfg()

This commit is contained in:
KaiyangZhou 2020-09-14 10:34:18 +01:00
parent fefa31dc3f
commit c4714fae69

View File

@ -88,6 +88,12 @@ def reset_config(cfg, args):
cfg.data.transforms = args.transforms
def check_cfg(cfg):
if cfg.loss.name == 'triplet' and cfg.loss.triplet.weight_x == 0:
assert cfg.train.fixbase_epoch == 0, \
'The output of classifier is not included in the computational graph'
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -130,6 +136,7 @@ def main():
reset_config(cfg, args)
cfg.merge_from_list(args.opts)
set_random_seed(cfg.train.seed)
check_cfg(cfg)
log_name = 'test.log' if cfg.test.evaluate else 'train.log'
log_name += time.strftime('-%Y-%m-%d-%H-%M-%S')