From 0f69dbe9028056ebfbcbc9928fece4d49b1572af Mon Sep 17 00:00:00 2001 From: Chen Jiayu <38110862+tuofeilunhifi@users.noreply.github.com> Date: Thu, 16 Jun 2022 14:46:47 +0800 Subject: [PATCH] Swint_cifar config bugfix (#95) --- .../cifar10/swintiny_b64_5e_jpg.py | 4 +- .../backbones/pytorch_image_models_wrapper.py | 40 ++++++++++--------- easycv/models/modelzoo.py | 4 ++ 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/configs/classification/cifar10/swintiny_b64_5e_jpg.py b/configs/classification/cifar10/swintiny_b64_5e_jpg.py index dad3629a..bbf91017 100644 --- a/configs/classification/cifar10/swintiny_b64_5e_jpg.py +++ b/configs/classification/cifar10/swintiny_b64_5e_jpg.py @@ -5,9 +5,9 @@ model = dict( backbone=dict( type='PytorchImageModelWrapper', model_name='swin_tiny_patch4_window7_224', - num_classes=10, + num_classes=0, ), - head=dict(type='ClsHead', with_fc=False)) + head=dict(type='ClsHead', in_channels=768, with_fc=True, num_classes=10)) # dataset settings class_list = [ 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', diff --git a/easycv/models/backbones/pytorch_image_models_wrapper.py b/easycv/models/backbones/pytorch_image_models_wrapper.py index 8d838c1a..6b141489 100644 --- a/easycv/models/backbones/pytorch_image_models_wrapper.py +++ b/easycv/models/backbones/pytorch_image_models_wrapper.py @@ -111,25 +111,29 @@ class PytorchImageModelWrapper(nn.Module): logger = get_root_logger() if pretrained: if self.model_name in self.timm_model_names: - default_pretrained_model_path = model_urls[self.model_name] - print_log( - 'load model from default path: {}'.format( - default_pretrained_model_path), logger) - if default_pretrained_model_path.endswith('.npz'): - pretrained_loc = download_cached_file( - default_pretrained_model_path, - check_hash=False, - progress=False) - return self.model.load_pretrained(pretrained_loc) + if self.model_name in model_urls: + default_pretrained_model_path = model_urls[self.model_name] + print_log( + 'load model from default path: {}'.format( + default_pretrained_model_path), logger) + if default_pretrained_model_path.endswith('.npz'): + pretrained_loc = download_cached_file( + default_pretrained_model_path, + check_hash=False, + progress=False) + return self.model.load_pretrained(pretrained_loc) + else: + backbone_module = importlib.import_module( + self.model.__module__) + return load_pretrained( + self.model, + default_cfg={'url': default_pretrained_model_path}, + filter_fn=backbone_module.checkpoint_filter_fn + if hasattr(backbone_module, 'checkpoint_filter_fn') + else None, + strict=False) else: - backbone_module = importlib.import_module( - self.model.__module__) - return load_pretrained( - self.model, - default_cfg={'url': default_pretrained_model_path}, - filter_fn=backbone_module.checkpoint_filter_fn - if hasattr(backbone_module, 'checkpoint_filter_fn') - else None) + logger.warning('pretrained model for model_name not found') elif self.model_name in _MODEL_MAP: if self.model_name in model_urls.keys(): default_pretrained_model_path = model_urls[self.model_name] diff --git a/easycv/models/modelzoo.py b/easycv/models/modelzoo.py index e85d353c..4b997b1b 100644 --- a/easycv/models/modelzoo.py +++ b/easycv/models/modelzoo.py @@ -94,6 +94,10 @@ timm_models = { 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_patch16_224-b5f2ef4d.pth', 'deit_base_distilled_patch16_224': 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_distilled_patch16_224-df68dfff.pth', + 'swin_tiny_patch4_window7_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_tiny_patch4_window7_224.pth', + 'swin_small_patch4_window7_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_small_patch4_window7_224.pth', 'swin_base_patch4_window7_224': 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_base_patch4_window7_224_22kto1k.pth', 'swin_large_patch4_window7_224':