diff --git a/scripts/main.py b/scripts/main.py index 05df742..61aa49d 100755 --- a/scripts/main.py +++ b/scripts/main.py @@ -88,6 +88,12 @@ def reset_config(cfg, args): cfg.data.transforms = args.transforms +def check_cfg(cfg): + if cfg.loss.name == 'triplet' and cfg.loss.triplet.weight_x == 0: + assert cfg.train.fixbase_epoch == 0, \ + 'The output of classifier is not included in the computational graph' + + def main(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -130,6 +136,7 @@ def main(): reset_config(cfg, args) cfg.merge_from_list(args.opts) set_random_seed(cfg.train.seed) + check_cfg(cfg) log_name = 'test.log' if cfg.test.evaluate else 'train.log' log_name += time.strftime('-%Y-%m-%d-%H-%M-%S')