mirror of https://github.com/alibaba/EasyCV.git
Swint_cifar config bugfix (#95)
parent
b737027aa4
commit
0f69dbe902
|
@ -5,9 +5,9 @@ model = dict(
|
|||
backbone=dict(
|
||||
type='PytorchImageModelWrapper',
|
||||
model_name='swin_tiny_patch4_window7_224',
|
||||
num_classes=10,
|
||||
num_classes=0,
|
||||
),
|
||||
head=dict(type='ClsHead', with_fc=False))
|
||||
head=dict(type='ClsHead', in_channels=768, with_fc=True, num_classes=10))
|
||||
# dataset settings
|
||||
class_list = [
|
||||
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
|
||||
|
|
|
@ -111,25 +111,29 @@ class PytorchImageModelWrapper(nn.Module):
|
|||
logger = get_root_logger()
|
||||
if pretrained:
|
||||
if self.model_name in self.timm_model_names:
|
||||
default_pretrained_model_path = model_urls[self.model_name]
|
||||
print_log(
|
||||
'load model from default path: {}'.format(
|
||||
default_pretrained_model_path), logger)
|
||||
if default_pretrained_model_path.endswith('.npz'):
|
||||
pretrained_loc = download_cached_file(
|
||||
default_pretrained_model_path,
|
||||
check_hash=False,
|
||||
progress=False)
|
||||
return self.model.load_pretrained(pretrained_loc)
|
||||
if self.model_name in model_urls:
|
||||
default_pretrained_model_path = model_urls[self.model_name]
|
||||
print_log(
|
||||
'load model from default path: {}'.format(
|
||||
default_pretrained_model_path), logger)
|
||||
if default_pretrained_model_path.endswith('.npz'):
|
||||
pretrained_loc = download_cached_file(
|
||||
default_pretrained_model_path,
|
||||
check_hash=False,
|
||||
progress=False)
|
||||
return self.model.load_pretrained(pretrained_loc)
|
||||
else:
|
||||
backbone_module = importlib.import_module(
|
||||
self.model.__module__)
|
||||
return load_pretrained(
|
||||
self.model,
|
||||
default_cfg={'url': default_pretrained_model_path},
|
||||
filter_fn=backbone_module.checkpoint_filter_fn
|
||||
if hasattr(backbone_module, 'checkpoint_filter_fn')
|
||||
else None,
|
||||
strict=False)
|
||||
else:
|
||||
backbone_module = importlib.import_module(
|
||||
self.model.__module__)
|
||||
return load_pretrained(
|
||||
self.model,
|
||||
default_cfg={'url': default_pretrained_model_path},
|
||||
filter_fn=backbone_module.checkpoint_filter_fn
|
||||
if hasattr(backbone_module, 'checkpoint_filter_fn')
|
||||
else None)
|
||||
logger.warning('pretrained model for model_name not found')
|
||||
elif self.model_name in _MODEL_MAP:
|
||||
if self.model_name in model_urls.keys():
|
||||
default_pretrained_model_path = model_urls[self.model_name]
|
||||
|
|
|
@ -94,6 +94,10 @@ timm_models = {
|
|||
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_patch16_224-b5f2ef4d.pth',
|
||||
'deit_base_distilled_patch16_224':
|
||||
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_distilled_patch16_224-df68dfff.pth',
|
||||
'swin_tiny_patch4_window7_224':
|
||||
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_tiny_patch4_window7_224.pth',
|
||||
'swin_small_patch4_window7_224':
|
||||
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_small_patch4_window7_224.pth',
|
||||
'swin_base_patch4_window7_224':
|
||||
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_base_patch4_window7_224_22kto1k.pth',
|
||||
'swin_large_patch4_window7_224':
|
||||
|
|
Loading…
Reference in New Issue