Add option to finetune on larger resolution (#43)

* Add option for finetuning a model

* Fixes

* Keep model in eval mode during finetuning

* Only skip head weights if size mismatch

* Remove finetune-epochs

Might not be needed

* Raise error if distillation + finetune are enabled
This commit is contained in:
Francisco Massa 2021-01-15 10:13:52 +01:00 committed by GitHub
parent d9932c08b5
commit a8e90967a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 6 deletions

View File

@ -19,9 +19,9 @@ import utils
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None):
# TODO fix this for finetuning
model.train()
model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
set_training_mode=True):
model.train(set_training_mode)
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)

47
main.py
View File

@ -132,6 +132,9 @@ def get_args_parser():
parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
# * Finetuning params
parser.add_argument('--finetune', default='', help='finetune from checkpoint')
# Dataset parameters
parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
help='dataset path')
@ -170,6 +173,9 @@ def main(args):
print(args)
if args.distillation_type != 'none' and args.finetune:
raise NotImplementedError("Finetuning with distillation not yet supported")
device = torch.device(args.device)
# fix the seed for reproducibility
@ -241,7 +247,41 @@ def main(args):
drop_block_rate=None,
)
# TODO: finetuning
if args.finetune:
if args.finetune.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.finetune, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.finetune, map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
model.load_state_dict(checkpoint_model, strict=False)
model.to(device)
@ -323,7 +363,7 @@ def main(args):
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
return
print("Start training")
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_accuracy = 0.0
for epoch in range(args.start_epoch, args.epochs):
@ -333,7 +373,8 @@ def main(args):
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
args.clip_grad, model_ema, mixup_fn
args.clip_grad, model_ema, mixup_fn,
set_training_mode=args.finetune == '' # keep in eval mode during finetuning
)
lr_scheduler.step(epoch)

View File

@ -139,3 +139,18 @@ def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
)
model.load_state_dict(checkpoint["model"])
return model
@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
model = VisionTransformer(
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model