diff --git a/train.py b/train.py index afd54a6f..3811f7b4 100644 --- a/train.py +++ b/train.py @@ -42,8 +42,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N', help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') parser.add_argument('--pretrained', action='store_true', default=False, help='Start with pretrained version of specified network (if avail)') -parser.add_argument('--img-size', type=int, default=224, metavar='N', - help='Image patch size (default: 224)') +parser.add_argument('--img-size', type=int, default=None, metavar='N', + help='Image patch size (default: None => model default)') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', @@ -159,15 +159,6 @@ def main(): torch.manual_seed(args.seed + args.rank) - output_dir = '' - if args.local_rank == 0: - output_base = args.output if args.output else './output' - exp_name = '-'.join([ - datetime.now().strftime("%Y%m%d-%H%M%S"), - args.model, - str(args.img_size)]) - output_dir = get_outdir(output_base, 'train', exp_name) - model = create_model( args.model, pretrained=args.pretrained, @@ -291,13 +282,21 @@ def main(): validate_loss_fn = train_loss_fn eval_metric = args.eval_metric - saver = None - if output_dir: - # only set if process is rank 0 - decreasing = True if eval_metric == 'loss' else False - saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) best_metric = None best_epoch = None + saver = None + output_dir = '' + if args.local_rank == 0: + output_base = args.output if args.output else './output' + exp_name = '-'.join([ + datetime.now().strftime("%Y%m%d-%H%M%S"), + args.model, + str(data_config['input_size'][-1]) + ]) + output_dir = get_outdir(output_base, 'train', exp_name) + decreasing = True if eval_metric == 'loss' else False + saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) + try: for epoch in range(start_epoch, num_epochs): if args.distributed: diff --git a/utils.py b/utils.py index 43f58117..626ae9dc 100644 --- a/utils.py +++ b/utils.py @@ -253,9 +253,9 @@ class ModelEma: name = k new_state_dict[name] = v self.ema.load_state_dict(new_state_dict) - print("=> loaded state_dict_ema") + print("=> Loaded state_dict_ema") else: - print("=> failed to find state_dict_ema, starting from loaded model weights)") + print("=> Failed to find state_dict_ema, starting from loaded model weights") def update(self, model): # correct a mismatch in state dict keys