Swint_cifar config bugfix (#95)

pull/97/head
Chen Jiayu 2022-06-16 14:46:47 +08:00 committed by GitHub
parent b737027aa4
commit 0f69dbe902
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 20 deletions

View File

@ -5,9 +5,9 @@ model = dict(
backbone=dict( backbone=dict(
type='PytorchImageModelWrapper', type='PytorchImageModelWrapper',
model_name='swin_tiny_patch4_window7_224', model_name='swin_tiny_patch4_window7_224',
num_classes=10, num_classes=0,
), ),
head=dict(type='ClsHead', with_fc=False)) head=dict(type='ClsHead', in_channels=768, with_fc=True, num_classes=10))
# dataset settings # dataset settings
class_list = [ class_list = [
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',

View File

@ -111,6 +111,7 @@ class PytorchImageModelWrapper(nn.Module):
logger = get_root_logger() logger = get_root_logger()
if pretrained: if pretrained:
if self.model_name in self.timm_model_names: if self.model_name in self.timm_model_names:
if self.model_name in model_urls:
default_pretrained_model_path = model_urls[self.model_name] default_pretrained_model_path = model_urls[self.model_name]
print_log( print_log(
'load model from default path: {}'.format( 'load model from default path: {}'.format(
@ -129,7 +130,10 @@ class PytorchImageModelWrapper(nn.Module):
default_cfg={'url': default_pretrained_model_path}, default_cfg={'url': default_pretrained_model_path},
filter_fn=backbone_module.checkpoint_filter_fn filter_fn=backbone_module.checkpoint_filter_fn
if hasattr(backbone_module, 'checkpoint_filter_fn') if hasattr(backbone_module, 'checkpoint_filter_fn')
else None) else None,
strict=False)
else:
logger.warning('pretrained model for model_name not found')
elif self.model_name in _MODEL_MAP: elif self.model_name in _MODEL_MAP:
if self.model_name in model_urls.keys(): if self.model_name in model_urls.keys():
default_pretrained_model_path = model_urls[self.model_name] default_pretrained_model_path = model_urls[self.model_name]

View File

@ -94,6 +94,10 @@ timm_models = {
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_patch16_224-b5f2ef4d.pth', 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_patch16_224-b5f2ef4d.pth',
'deit_base_distilled_patch16_224': 'deit_base_distilled_patch16_224':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_distilled_patch16_224-df68dfff.pth', 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_distilled_patch16_224-df68dfff.pth',
'swin_tiny_patch4_window7_224':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_tiny_patch4_window7_224.pth',
'swin_small_patch4_window7_224':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_small_patch4_window7_224.pth',
'swin_base_patch4_window7_224': 'swin_base_patch4_window7_224':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_base_patch4_window7_224_22kto1k.pth', 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_base_patch4_window7_224_22kto1k.pth',
'swin_large_patch4_window7_224': 'swin_large_patch4_window7_224':