mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
add check_cfg()
This commit is contained in:
parent
fefa31dc3f
commit
c4714fae69
@ -88,6 +88,12 @@ def reset_config(cfg, args):
|
||||
cfg.data.transforms = args.transforms
|
||||
|
||||
|
||||
def check_cfg(cfg):
|
||||
if cfg.loss.name == 'triplet' and cfg.loss.triplet.weight_x == 0:
|
||||
assert cfg.train.fixbase_epoch == 0, \
|
||||
'The output of classifier is not included in the computational graph'
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -130,6 +136,7 @@ def main():
|
||||
reset_config(cfg, args)
|
||||
cfg.merge_from_list(args.opts)
|
||||
set_random_seed(cfg.train.seed)
|
||||
check_cfg(cfg)
|
||||
|
||||
log_name = 'test.log' if cfg.test.evaluate else 'train.log'
|
||||
log_name += time.strftime('-%Y-%m-%d-%H-%M-%S')
|
||||
|
Loading…
x
Reference in New Issue
Block a user