Mirror of https://github.com/facebookresearch/deit.git, synced 2025-06-03 14:52:20 +08:00
Add option to finetune on larger resolution (#43)
* Add option for finetuning a model
* Fixes
* Keep model in eval mode during finetuning
* Only skip head weights if size mismatch
* Remove finetune-epochs (might not be needed)
* Raise error if distillation + finetune are enabled
parent d9932c08b5
commit a8e90967a3
engine.py (6 changed lines)

@@ -19,9 +19,9 @@ import utils
 def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
                     device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
-                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None):
-    # TODO fix this for finetuning
-    model.train()
+                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
+                    set_training_mode=True):
+    model.train(set_training_mode)
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
     header = 'Epoch: [{}]'.format(epoch)
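The hunk above threads a `set_training_mode` flag straight through to `nn.Module.train()`, which is what lets the caller keep the model in eval mode while finetuning. A minimal sketch (not taken from the repo) of what that boolean controls:

```python
import torch
import torch.nn as nn

# nn.Module.train() accepts a boolean, so train_one_epoch can simply forward
# set_training_mode to it. With False, stochastic layers such as Dropout are
# disabled, even though gradients still flow during the finetuning loop.
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))

model.train(True)    # regular training: dropout is active
model.train(False)   # same effect as model.eval(): dropout is disabled

x = torch.randn(2, 8)
assert torch.equal(model(x), model(x))  # deterministic because dropout is off
```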
main.py (47 changed lines)
@@ -132,6 +132,9 @@ def get_args_parser():
     parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
     parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
 
+    # * Finetuning params
+    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
+
     # Dataset parameters
     parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                         help='dataset path')
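For reference, a small sketch of how the new flag surfaces to users. It assumes the snippet runs from the repository root so `main.get_args_parser` is importable, and the checkpoint path is a placeholder:

```python
# Hypothetical parse of the new option added above; all other arguments
# keep their defaults from get_args_parser().
from main import get_args_parser

args = get_args_parser().parse_args(
    ['--finetune', 'checkpoints/deit_base_patch16_224.pth'])

print(args.finetune)        # path (or https URL) of the checkpoint to start from
print(bool(args.finetune))  # True -> the finetuning branch in main() will run
```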
@@ -170,6 +173,9 @@ def main(args):
 
     print(args)
 
+    if args.distillation_type != 'none' and args.finetune:
+        raise NotImplementedError("Finetuning with distillation not yet supported")
+
     device = torch.device(args.device)
 
     # fix the seed for reproducibility
@@ -241,7 +247,41 @@ def main(args):
         drop_block_rate=None,
     )
 
-    # TODO: finetuning
+    if args.finetune:
+        if args.finetune.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.finetune, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.finetune, map_location='cpu')
+
+        checkpoint_model = checkpoint['model']
+        state_dict = model.state_dict()
+        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
+            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
+                print(f"Removing key {k} from pretrained checkpoint")
+                del checkpoint_model[k]
+
+        # interpolate position embedding
+        pos_embed_checkpoint = checkpoint_model['pos_embed']
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+        # only the position tokens are interpolated
+        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+        pos_tokens = torch.nn.functional.interpolate(
+            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+        checkpoint_model['pos_embed'] = new_pos_embed
+
+        model.load_state_dict(checkpoint_model, strict=False)
 
     model.to(device)
 
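The interpolation block above is easiest to follow with concrete shapes. A self-contained sketch with dummy tensors (numbers chosen for the 224 → 384, patch-16 case; not code from the repo):

```python
import torch
import torch.nn.functional as F

# A 224px / patch-16 checkpoint carries a 14x14 grid of position tokens (196)
# plus one class token; the 384px model needs a 24x24 grid (576).
embedding_size = 768
num_extra_tokens = 1          # only the class token in this non-distilled sketch
orig_size, new_size = 14, 24

pos_embed_checkpoint = torch.randn(1, num_extra_tokens + orig_size ** 2, embedding_size)

extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]           # kept as-is
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]             # (1, 196, 768)
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = F.interpolate(pos_tokens, size=(new_size, new_size),
                           mode='bicubic', align_corners=False)     # (1, 768, 24, 24)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)           # (1, 576, 768)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)

print(new_pos_embed.shape)    # torch.Size([1, 577, 768])
```

For the distilled variants the extra-token count is 2 (class + distillation token), which is why the diff computes `num_extra_tokens` from `model.pos_embed` instead of hard-coding it.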
@@ -323,7 +363,7 @@ def main(args):
         print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
         return
 
-    print("Start training")
+    print(f"Start training for {args.epochs} epochs")
     start_time = time.time()
     max_accuracy = 0.0
     for epoch in range(args.start_epoch, args.epochs):
@@ -333,7 +373,8 @@ def main(args):
         train_stats = train_one_epoch(
             model, criterion, data_loader_train,
             optimizer, device, epoch, loss_scaler,
-            args.clip_grad, model_ema, mixup_fn
+            args.clip_grad, model_ema, mixup_fn,
+            set_training_mode=args.finetune == ''  # keep in eval mode during finetuning
         )
 
         lr_scheduler.step(epoch)
models.py (15 changed lines)
@@ -139,3 +139,18 @@ def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
         )
         model.load_state_dict(checkpoint["model"])
     return model
+
+
+@register_model
+def deit_base_patch16_384(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = _cfg()
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url="",
+            map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    return model
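Because the new variant is decorated with `@register_model`, it can be built through timm's factory once `models.py` has been imported, which is how `main.py` resolves `--model`. A hedged usage sketch (assumes it runs from the repo root; the pretrained URL is still empty in this commit, so weights arrive via `--finetune` instead):

```python
from timm.models import create_model

import models  # noqa: F401  -- importing registers the deit_* variants with timm

model = create_model('deit_base_patch16_384', pretrained=False, num_classes=1000)

# 384px input with 16px patches gives a 24x24 grid, plus the class token:
print(model.pos_embed.shape)   # torch.Size([1, 577, 768])
```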