mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Default to img_size in model default_cfg, defer output folder creation until later in the init sequence
This commit is contained in:
parent
9bcd65181b
commit
7dab6d1ec7
31
train.py
31
train.py
@ -42,8 +42,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
|
||||
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
|
||||
parser.add_argument('--pretrained', action='store_true', default=False,
|
||||
help='Start with pretrained version of specified network (if avail)')
|
||||
parser.add_argument('--img-size', type=int, default=224, metavar='N',
|
||||
help='Image patch size (default: 224)')
|
||||
parser.add_argument('--img-size', type=int, default=None, metavar='N',
|
||||
help='Image patch size (default: None => model default)')
|
||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||
help='Override mean pixel value of dataset')
|
||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||
@ -159,15 +159,6 @@ def main():
|
||||
|
||||
torch.manual_seed(args.seed + args.rank)
|
||||
|
||||
output_dir = ''
|
||||
if args.local_rank == 0:
|
||||
output_base = args.output if args.output else './output'
|
||||
exp_name = '-'.join([
|
||||
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
||||
args.model,
|
||||
str(args.img_size)])
|
||||
output_dir = get_outdir(output_base, 'train', exp_name)
|
||||
|
||||
model = create_model(
|
||||
args.model,
|
||||
pretrained=args.pretrained,
|
||||
@ -291,13 +282,21 @@ def main():
|
||||
validate_loss_fn = train_loss_fn
|
||||
|
||||
eval_metric = args.eval_metric
|
||||
saver = None
|
||||
if output_dir:
|
||||
# only set if process is rank 0
|
||||
decreasing = True if eval_metric == 'loss' else False
|
||||
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
|
||||
best_metric = None
|
||||
best_epoch = None
|
||||
saver = None
|
||||
output_dir = ''
|
||||
if args.local_rank == 0:
|
||||
output_base = args.output if args.output else './output'
|
||||
exp_name = '-'.join([
|
||||
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
||||
args.model,
|
||||
str(data_config['input_size'][-1])
|
||||
])
|
||||
output_dir = get_outdir(output_base, 'train', exp_name)
|
||||
decreasing = True if eval_metric == 'loss' else False
|
||||
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
|
||||
|
||||
try:
|
||||
for epoch in range(start_epoch, num_epochs):
|
||||
if args.distributed:
|
||||
|
4
utils.py
4
utils.py
@ -253,9 +253,9 @@ class ModelEma:
|
||||
name = k
|
||||
new_state_dict[name] = v
|
||||
self.ema.load_state_dict(new_state_dict)
|
||||
print("=> loaded state_dict_ema")
|
||||
print("=> Loaded state_dict_ema")
|
||||
else:
|
||||
print("=> failed to find state_dict_ema, starting from loaded model weights)")
|
||||
print("=> Failed to find state_dict_ema, starting from loaded model weights")
|
||||
|
||||
def update(self, model):
|
||||
# correct a mismatch in state dict keys
|
||||
|
Loading…
x
Reference in New Issue
Block a user