mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Default to img_size in model default_cfg, defer output folder creation until later in the init sequence
This commit is contained in:
parent
9bcd65181b
commit
7dab6d1ec7
31
train.py
31
train.py
@ -42,8 +42,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
|
|||||||
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
|
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
|
||||||
parser.add_argument('--pretrained', action='store_true', default=False,
|
parser.add_argument('--pretrained', action='store_true', default=False,
|
||||||
help='Start with pretrained version of specified network (if avail)')
|
help='Start with pretrained version of specified network (if avail)')
|
||||||
parser.add_argument('--img-size', type=int, default=224, metavar='N',
|
parser.add_argument('--img-size', type=int, default=None, metavar='N',
|
||||||
help='Image patch size (default: 224)')
|
help='Image patch size (default: None => model default)')
|
||||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||||
help='Override mean pixel value of dataset')
|
help='Override mean pixel value of dataset')
|
||||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||||
@ -159,15 +159,6 @@ def main():
|
|||||||
|
|
||||||
torch.manual_seed(args.seed + args.rank)
|
torch.manual_seed(args.seed + args.rank)
|
||||||
|
|
||||||
output_dir = ''
|
|
||||||
if args.local_rank == 0:
|
|
||||||
output_base = args.output if args.output else './output'
|
|
||||||
exp_name = '-'.join([
|
|
||||||
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
|
||||||
args.model,
|
|
||||||
str(args.img_size)])
|
|
||||||
output_dir = get_outdir(output_base, 'train', exp_name)
|
|
||||||
|
|
||||||
model = create_model(
|
model = create_model(
|
||||||
args.model,
|
args.model,
|
||||||
pretrained=args.pretrained,
|
pretrained=args.pretrained,
|
||||||
@ -291,13 +282,21 @@ def main():
|
|||||||
validate_loss_fn = train_loss_fn
|
validate_loss_fn = train_loss_fn
|
||||||
|
|
||||||
eval_metric = args.eval_metric
|
eval_metric = args.eval_metric
|
||||||
saver = None
|
|
||||||
if output_dir:
|
|
||||||
# only set if process is rank 0
|
|
||||||
decreasing = True if eval_metric == 'loss' else False
|
|
||||||
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
|
|
||||||
best_metric = None
|
best_metric = None
|
||||||
best_epoch = None
|
best_epoch = None
|
||||||
|
saver = None
|
||||||
|
output_dir = ''
|
||||||
|
if args.local_rank == 0:
|
||||||
|
output_base = args.output if args.output else './output'
|
||||||
|
exp_name = '-'.join([
|
||||||
|
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
||||||
|
args.model,
|
||||||
|
str(data_config['input_size'][-1])
|
||||||
|
])
|
||||||
|
output_dir = get_outdir(output_base, 'train', exp_name)
|
||||||
|
decreasing = True if eval_metric == 'loss' else False
|
||||||
|
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(start_epoch, num_epochs):
|
for epoch in range(start_epoch, num_epochs):
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
|
4
utils.py
4
utils.py
@ -253,9 +253,9 @@ class ModelEma:
|
|||||||
name = k
|
name = k
|
||||||
new_state_dict[name] = v
|
new_state_dict[name] = v
|
||||||
self.ema.load_state_dict(new_state_dict)
|
self.ema.load_state_dict(new_state_dict)
|
||||||
print("=> loaded state_dict_ema")
|
print("=> Loaded state_dict_ema")
|
||||||
else:
|
else:
|
||||||
print("=> failed to find state_dict_ema, starting from loaded model weights)")
|
print("=> Failed to find state_dict_ema, starting from loaded model weights")
|
||||||
|
|
||||||
def update(self, model):
|
def update(self, model):
|
||||||
# correct a mismatch in state dict keys
|
# correct a mismatch in state dict keys
|
||||||
|
Loading…
x
Reference in New Issue
Block a user