Swint_cifar config bugfix (#95)

pull/97/head
Chen Jiayu 2022-06-16 14:46:47 +08:00 committed by GitHub
parent b737027aa4
commit 0f69dbe902
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 20 deletions

View File

@ -5,9 +5,9 @@ model = dict(
backbone=dict(
type='PytorchImageModelWrapper',
model_name='swin_tiny_patch4_window7_224',
num_classes=10,
num_classes=0,
),
head=dict(type='ClsHead', with_fc=False))
head=dict(type='ClsHead', in_channels=768, with_fc=True, num_classes=10))
# dataset settings
class_list = [
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',

View File

@ -111,25 +111,29 @@ class PytorchImageModelWrapper(nn.Module):
logger = get_root_logger()
if pretrained:
if self.model_name in self.timm_model_names:
default_pretrained_model_path = model_urls[self.model_name]
print_log(
'load model from default path: {}'.format(
default_pretrained_model_path), logger)
if default_pretrained_model_path.endswith('.npz'):
pretrained_loc = download_cached_file(
default_pretrained_model_path,
check_hash=False,
progress=False)
return self.model.load_pretrained(pretrained_loc)
if self.model_name in model_urls:
default_pretrained_model_path = model_urls[self.model_name]
print_log(
'load model from default path: {}'.format(
default_pretrained_model_path), logger)
if default_pretrained_model_path.endswith('.npz'):
pretrained_loc = download_cached_file(
default_pretrained_model_path,
check_hash=False,
progress=False)
return self.model.load_pretrained(pretrained_loc)
else:
backbone_module = importlib.import_module(
self.model.__module__)
return load_pretrained(
self.model,
default_cfg={'url': default_pretrained_model_path},
filter_fn=backbone_module.checkpoint_filter_fn
if hasattr(backbone_module, 'checkpoint_filter_fn')
else None,
strict=False)
else:
backbone_module = importlib.import_module(
self.model.__module__)
return load_pretrained(
self.model,
default_cfg={'url': default_pretrained_model_path},
filter_fn=backbone_module.checkpoint_filter_fn
if hasattr(backbone_module, 'checkpoint_filter_fn')
else None)
logger.warning('pretrained model for model_name not found')
elif self.model_name in _MODEL_MAP:
if self.model_name in model_urls.keys():
default_pretrained_model_path = model_urls[self.model_name]

View File

@ -94,6 +94,10 @@ timm_models = {
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_patch16_224-b5f2ef4d.pth',
'deit_base_distilled_patch16_224':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_distilled_patch16_224-df68dfff.pth',
'swin_tiny_patch4_window7_224':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_tiny_patch4_window7_224.pth',
'swin_small_patch4_window7_224':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_small_patch4_window7_224.pth',
'swin_base_patch4_window7_224':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_base_patch4_window7_224_22kto1k.pth',
'swin_large_patch4_window7_224':