bugfix_swintiny (#79)

1. add easycv_model default_path
2. bugfix pretrained
pull/86/head
Chen Jiayu 2022-06-06 12:03:05 +08:00 committed by GitHub
parent 5bc5cb6f50
commit 12e3bed42b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
62 changed files with 113 additions and 112 deletions

View File

@ -6,7 +6,6 @@ model = dict(
type='PytorchImageModelWrapper',
model_name='swin_tiny_patch4_window7_224',
num_classes=10,
pretrained=True,
),
head=dict(type='ClsHead', with_fc=False))
# dataset settings

View File

@ -5,5 +5,4 @@ model = dict(
type='PytorchImageModelWrapper',
model_name='swin_base_patch4_window7_224',
num_classes=1000,
pretrained=False,
))

View File

@ -5,5 +5,4 @@ model = dict(
type='PytorchImageModelWrapper',
model_name='swin_large_patch4_window7_224',
num_classes=1000,
pretrained=False,
))

View File

@ -5,5 +5,4 @@ model = dict(
type='PytorchImageModelWrapper',
model_name='swin_tiny_patch4_window7_224',
num_classes=1000,
pretrained=False,
))

View File

@ -9,6 +9,7 @@ log_config = dict(
model = dict(
type='Classification',
train_preprocess=['mixUp'],
pretrained=False,
mixup_cfg=dict(
mixup_alpha=0.8,
cutmix_alpha=1.0,
@ -20,7 +21,6 @@ model = dict(
type='PytorchImageModelWrapper',
model_name='swin_tiny_patch4_window7_224',
num_classes=1000,
pretrained=False,
),
head=dict(
type='ClsHead',

View File

@ -9,7 +9,6 @@ log_config = dict(
model = dict(
type='Classification',
train_preprocess=['mixUp'],
pretrained=True,
mixup_cfg=dict(
mixup_alpha=0.2,
prob=1.0,

View File

@ -5,5 +5,4 @@ model = dict(
type='PytorchImageModelWrapper',
model_name='vit_base_patch16_224',
num_classes=1000,
pretrained=False,
))

View File

@ -5,5 +5,4 @@ model = dict(
type='PytorchImageModelWrapper',
model_name='vit_base_patch32_224',
num_classes=1000,
pretrained=False,
))

View File

@ -5,5 +5,4 @@ model = dict(
type='PytorchImageModelWrapper',
model_name='vit_large_patch16_224',
num_classes=1000,
pretrained=False,
))

View File

@ -5,5 +5,4 @@ model = dict(
type='PytorchImageModelWrapper',
model_name='vit_large_patch32_224',
num_classes=1000,
pretrained=False,
))

View File

@ -9,6 +9,7 @@ log_config = dict(
model = dict(
type='Classification',
train_preprocess=['mixUp'],
pretrained=False,
mixup_cfg=dict(
mixup_alpha=0.2,
prob=1.0,
@ -19,7 +20,6 @@ model = dict(
type='PytorchImageModelWrapper',
model_name='vit_base_patch16_224',
num_classes=1000,
pretrained=False,
),
head=dict(
type='ClsHead',

View File

@ -29,7 +29,6 @@ model = dict(
# model_name = 'vit_deit_small_distilled_patch16_224', # good 384,
model_name='resnet50',
num_classes=0,
pretrained=True,
),
head=dict(
type='ClsHead', with_avg_pool=True, in_channels=2048,

View File

@ -21,7 +21,7 @@ model_output_dim = 65536
model = dict(
type='DINO',
pretrained=None,
pretrained=False,
train_preprocess=[
'randomGrayScale', 'gaussianBlur', 'solarize'
], # 2+6 view, has different augment pipeline, dino is complex
@ -42,7 +42,6 @@ model = dict(
# model_name = 'vit_deit_small_distilled_patch16_224', # good 384,
model_name='resnet50',
num_classes=0,
# pretrained=True,
),
# swav need mulit crop ,doesn't support vit based model

View File

@ -15,7 +15,6 @@ model_output_dim = 65536
model = dict(
type='DINO',
pretrained=None,
train_preprocess=[
'randomGrayScale', 'gaussianBlur', 'solarize'
], # 2+6 view, has different augment pipeline, dino is complex
@ -41,7 +40,6 @@ model = dict(
# model_name = 'resnet18',
# model_name = 'resnet34',
# model_name = 'resnet101',
pretrained=True,
),
# swav need mulit crop ,doesn't support vit based model

View File

@ -20,7 +20,7 @@ model_output_dim = 65536
model = dict(
type='DINO',
pretrained=None,
pretrained=False,
train_preprocess=[
'randomGrayScale', 'gaussianBlur', 'solarize'
], # 2+6 view, has different augment pipeline, dino is complex
@ -46,7 +46,6 @@ model = dict(
# model_name = 'resnet34',
# model_name = 'resnet101',
# num_classes=0,
# pretrained=True,
),
# swav need mulit crop ,doesn't support vit based model

View File

@ -67,7 +67,6 @@ model = dict(
# model_name = 'resnet34',
# model_name = 'resnet101',
num_classes='${backbone_channel}',
pretrained=True,
),
neck=dict(
type='RetrivalNeck',

View File

@ -72,7 +72,6 @@ model = dict(
# model_name = 'resnet34',
# model_name = 'resnet101',
num_classes='${backbone_channel}',
pretrained=True,
),
neck=dict(
type='RetrivalNeck',

View File

@ -14,7 +14,7 @@ model = dict(
train_preprocess=['randomGrayScale', 'gaussianBlur'],
queue_len=4096,
momentum=0.99,
pretrained=None,
pretrained=False,
# backbone=dict(
# type='ResNet',
# depth=50,
@ -40,7 +40,6 @@ model = dict(
# model_name='shuffletrans_tiny_p4_w7_224', #768
model_name='resnet50', # 2048
num_classes=0,
pretrained=False,
),
neck=dict(
type='MoBYMLP',

View File

@ -3,7 +3,7 @@ _base_ = 'configs/base.py'
model = dict(
type='MOCO',
pretrained=None,
pretrained=False,
train_preprocess=['randomGrayScale', 'gaussianBlur'],
queue_len=65536,
feat_dim=128,
@ -24,7 +24,6 @@ model = dict(
# #model_name = 'vit_deit_small_distilled_patch16_224', # good 384,
# #model_name = 'resnet50',
# #num_classes=1000,
# pretrained=True,
# ),
neck=dict(
type='NonLinearNeckV1',

View File

@ -13,7 +13,7 @@ work_dir = 'oss://path/to/work_dirs/moco_r50_oss/'
model = dict(
type='MOCO',
pretrained=None,
pretrained=False,
train_preprocess=['randomGrayScale', 'gaussianBlur'],
queue_len=65536,
feat_dim=128,
@ -34,7 +34,6 @@ model = dict(
# #model_name = 'vit_deit_small_distilled_patch16_224', # good 384,
# #model_name = 'resnet50',
# #num_classes=1000,
# pretrained=True,
# ),
neck=dict(
type='NonLinearNeckV1',

View File

@ -3,7 +3,6 @@ _base_ = 'configs/base.py'
model = dict(
type='MOCO',
pretrained=None,
train_preprocess=['randomGrayScale', 'gaussianBlur'],
queue_len=65536,
feat_dim=128,
@ -21,7 +20,6 @@ model = dict(
# model_name='xcit_medium_24_p8', #384
model_name='xcit_large_24_p8', # 768
num_classes=0,
pretrained=True,
),
neck=dict(
type='NonLinearNeckV1',

View File

@ -13,7 +13,6 @@ work_dir = 'oss://path/to/work_dirs/moco_timm_oss/'
model = dict(
type='MOCO',
pretrained=None,
train_preprocess=['randomGrayScale', 'gaussianBlur'],
queue_len=65536,
feat_dim=128,
@ -29,7 +28,6 @@ model = dict(
# model_name = 'resnet50',
model_name='xcit_small_12_p8',
num_classes=0,
pretrained=True,
),
neck=dict(
type='NonLinearNeckV1',

View File

@ -5,7 +5,7 @@ num_crops = [2, 6]
# model settings
model = dict(
type='SWAV',
pretrained=None,
pretrained=False,
train_preprocess=['randomGrayScale', 'gaussianBlur'],
backbone=dict(
type='ResNet',

View File

@ -15,7 +15,7 @@ num_crops = [2, 6]
model = dict(
type='SWAV',
pretrained=None,
pretrained=False,
train_preprocess=['randomGrayScale', 'gaussianBlur'],
backbone=dict(
type='ResNet',

View File

@ -27,7 +27,7 @@ channel_cfg = dict(
# model settings
model = dict(
type='TopDown',
pretrained=None,
pretrained=False,
backbone=dict(
type='LiteHRNet',
in_channels=3,

View File

@ -5,7 +5,7 @@ model = dict(
type='Classification',
train_preprocess=['randomErasing'],
# train_preprocess=['mixUp'],
pretrained=None,
pretrained=False,
backbone=dict(
type='ResNet',
depth=50,

View File

@ -18,9 +18,7 @@ model = dict(
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/swin_base_patch4_window7_224_22k_statedict.pth',
backbone=dict(
type='PytorchImageModelWrapper',
model_name='swin_base_patch4_window7_224_in22k'
# pretrained=False,
),
model_name='swin_base_patch4_window7_224_in22k'),
neck=dict(
type='RetrivalNeck',
in_channels=1024,

View File

@ -40,7 +40,7 @@ channel_cfg = dict(
# model settings
model = dict(
type='TopDown',
pretrained=None,
pretrained=False,
backbone=dict(
type='LiteHRNet',
in_channels=3,

View File

@ -3,7 +3,7 @@ _base_ = '../../base.py'
# model settings
model = dict(
type='BYOL',
pretrained=None,
pretrained=False,
base_momentum=0.996,
backbone=dict(
type='ResNet',

View File

@ -7,7 +7,7 @@ model_output_dim = 65536
model = dict(
type='DINO',
pretrained=None,
pretrained=False,
train_preprocess=[
'randomGrayScale', 'gaussianBlur', 'solarize'
], # 2+6 view, has different augment pipeline, dino is complex
@ -33,7 +33,6 @@ model = dict(
# model_name = 'resnet18',
# model_name = 'resnet34',
# model_name = 'resnet101',
# pretrained=True,
),
# swav need mulit crop ,doesn't support vit based model

View File

@ -9,7 +9,7 @@ model_output_dim = 65536
model = dict(
type='DINO',
pretrained=None,
pretrained=False,
train_preprocess=[
'randomGrayScale', 'gaussianBlur', 'solarize'
], # 2+6 view, has different augment pipeline, dino is complex
@ -35,7 +35,6 @@ model = dict(
# model_name = 'resnet34',
# model_name = 'resnet101',
# num_classes=0,
# pretrained=True,
),
# swav need mulit crop ,doesn't support vit based model

View File

@ -12,7 +12,7 @@ _base_ = '../../base.py'
# model settings
model = dict(
type='MIXCO',
pretrained=None,
pretrained=False,
train_preprocess=['randomGrayScale', 'gaussianBlur', 'mixUp'],
queue_len=65536,
feat_dim=128,

View File

@ -12,7 +12,7 @@ _base_ = '../../base.py'
# model settings
model = dict(
type='MIXCO',
pretrained=None,
pretrained=False,
train_preprocess=['randomGrayScale', 'gaussianBlur', 'mixUp'],
queue_len=65536,
feat_dim=128,

View File

@ -6,7 +6,7 @@ model = dict(
train_preprocess=['randomGrayScale', 'gaussianBlur'],
queue_len=4096,
momentum=0.99,
pretrained=None,
pretrained=False,
backbone=dict(
type='PytorchImageModelWrapper',
# model_name='pit_xs_distilled_224',
@ -16,7 +16,6 @@ model = dict(
# model_name='vit_deit_tiny_distilled_patch16_224', # good 192,
model_name='vit_deit_small_distilled_patch16_224', # good 384,
# model_name = 'resnet50',
# pretrained=True,
),
neck=dict(
type='MoBYMLP',

View File

@ -15,7 +15,7 @@ model = dict(
train_preprocess=['randomGrayScale', 'gaussianBlur'],
queue_len=4096,
momentum=0.99,
pretrained=None,
pretrained=False,
backbone=dict(
type='PytorchImageModelWrapper',
# model_name='pit_xs_distilled_224',
@ -30,7 +30,6 @@ model = dict(
# model_name='shuffletrans_tiny_p4_w7_224', #768
# model_name = 'resnet50', # 2048
num_classes=0,
pretrained=False,
),
neck=dict(
type='MoBYMLP',

View File

@ -15,12 +15,11 @@ model = dict(
train_preprocess=['randomGrayScale', 'gaussianBlur'],
queue_len=4096,
momentum=0.99,
pretrained=None,
pretrained=False,
backbone=dict(
type='PytorchImageModelWrapper',
model_name='dynamic_swin_tiny_p4_w7_224',
num_classes=0,
pretrained=False,
),
neck=dict(
type='MoBYMLP',

View File

@ -6,7 +6,6 @@ model = dict(
train_preprocess=['randomGrayScale', 'gaussianBlur'],
queue_len=4096,
momentum=0.99,
pretrained=None,
# backbone=dict(
# type='ResNet',
# depth=50,
@ -31,7 +30,6 @@ model = dict(
# model_name='shuffletrans_tiny_p4_w7_224', #768
model_name='resnet50', # 2048
num_classes=0,
pretrained=True,
),
neck=dict(
type='MoBYMLP',

View File

@ -3,7 +3,7 @@ _base_ = '../../base.py'
# model settings
model = dict(
type='MOCO',
pretrained=None,
pretrained=False,
train_preprocess=['gaussianBlur'],
queue_len=65536,
feat_dim=128,

View File

@ -3,7 +3,7 @@ _base_ = '../../base.py'
# model settings
model = dict(
type='MOCO',
pretrained=None,
pretrained=False,
train_preprocess=['randomGrayScale', 'gaussianBlur'],
queue_len=65536,
feat_dim=128,
@ -24,7 +24,6 @@ model = dict(
# #model_name = 'vit_deit_small_distilled_patch16_224', # good 384,
# #model_name = 'resnet50',
# #num_classes=1000,
# pretrained=True,
# ),
neck=dict(
type='NonLinearNeckV1',

View File

@ -3,7 +3,7 @@ _base_ = '../../base.py'
# model settings
model = dict(
type='SimCLR',
pretrained=None,
pretrained=False,
backbone=dict(
type='ResNet',
depth=50,

View File

@ -6,7 +6,7 @@ work_dir = 'work_dir/simclr/'
# model settings
model = dict(
type='SimCLR',
pretrained=None,
pretrained=False,
train_preprocess=['randomGrayScale', 'gaussianBlur'],
backbone=dict(
type='ResNet',

View File

@ -3,7 +3,7 @@ _base_ = '../../base.py'
# model settings
model = dict(
type='SimCLR',
pretrained=None,
pretrained=False,
backbone=dict(
type='ResNet',
depth=50,

View File

@ -6,7 +6,7 @@ num_crops = [2, 6]
# model settings
model = dict(
type='SWAV',
pretrained=None,
pretrained=False,
train_preprocess=['randomGrayScale', 'gaussianBlur'],
backbone=dict(type='PlainNet', plainnet_struct_idx='normal'),
neck=dict(

View File

@ -5,7 +5,7 @@ num_crops = [2, 6]
model = dict(
type='SWAV',
pretrained=None,
pretrained=False,
train_preprocess=['randomGrayScale', 'gaussianBlur'],
backbone=dict(
type='ResNet',

View File

@ -9,8 +9,6 @@ from mmcv.cnn import constant_init, kaiming_init
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..modelzoo import bninception as model_urls
from ..registry import BACKBONES
@ -348,6 +346,9 @@ class BNInception(nn.Module):
if num_classes > 0:
self.last_linear = nn.Linear(1024, num_classes)
self.default_pretrained_model_path = model_urls[
self.__class__.__name__]
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):

View File

@ -8,8 +8,6 @@ from mmcv.cnn import constant_init, kaiming_init
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..modelzoo import genet as model_urls
from ..registry import BACKBONES
@ -1660,6 +1658,9 @@ class PlainNet(nn.Module):
self.plainnet_struct = str(self) + str(self.adptive_avg_pool)
self.zero_init_residual = False
self.default_pretrained_model_path = model_urls[self.__class__.__name__
+ plainnet_struct_idx]
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):

View File

@ -8,8 +8,7 @@ from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
from torch.nn.modules.batchnorm import _BatchNorm
from easycv.models.registry import BACKBONES
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..modelzoo import hrnet as model_urls
from .resnet import BasicBlock
@ -592,6 +591,9 @@ class HRNet(nn.Module):
multiscale_output=self.stage4_cfg.get('multiscale_output',
multi_scale_output))
self.default_pretrained_model_path = model_urls.get(
self.__class__.__name__ + arch, None)
@property
def norm1(self):
"""nn.Module: the normalization layer named "norm1" """
@ -718,12 +720,6 @@ class HRNet(nn.Module):
return nn.Sequential(*hr_modules), in_channels
def init_weights(self):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, std=0.001)

View File

@ -11,8 +11,6 @@ import torch.nn.functional as F
from mmcv.cnn import constant_init, kaiming_init
from torch.nn.modules.batchnorm import _BatchNorm
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..modelzoo import inceptionv3 as model_urls
from ..registry import BACKBONES
@ -60,6 +58,9 @@ class Inception3(nn.Module):
if num_classes > 0:
self.fc = nn.Linear(2048, num_classes)
self.default_pretrained_model_path = model_urls[
self.__class__.__name__]
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):

View File

@ -12,8 +12,6 @@ from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
from torch.nn.modules.batchnorm import _BatchNorm
from easycv.models.registry import BACKBONES
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
def channel_shuffle(x, groups):

View File

@ -6,8 +6,6 @@ r""" This model is taken from the official PyTorch model zoo.
import torch
from torch import nn
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..modelzoo import mnasnet as model_urls
from ..registry import BACKBONES
@ -144,6 +142,9 @@ class MNASNet(torch.nn.Module):
nn.Dropout(p=dropout, inplace=True),
nn.Linear(1280, num_classes))
self.default_pretrained_model_path = model_urls[self.__class__.__name__
+ str(alpha)]
def forward(self, x):
x = self.layers(x)
# Equivalent to global avgpool and removing H and W dimensions.

View File

@ -5,8 +5,6 @@ r""" This model is taken from the official PyTorch model zoo.
from torch import nn
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..modelzoo import mobilenetv2 as model_urls
from ..registry import BACKBONES
@ -153,8 +151,9 @@ class MobileNetV2(nn.Module):
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)
self.pretrained = model_urls[self.__class__.__name__ + '_' +
str(width_multi)]
self.default_pretrained_model_path = model_urls[self.__class__.__name__
+ '_' +
str(width_multi)]
def init_weights(self):
for m in self.modules():

View File

@ -8,7 +8,6 @@ import torch.nn as nn
from timm.models.helpers import load_pretrained
from timm.models.hub import download_cached_file
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger, print_log
from ..modelzoo import timm_models as model_urls
from ..registry import BACKBONES
@ -113,35 +112,37 @@ class PytorchImageModelWrapper(nn.Module):
logger = get_root_logger()
if pretrained:
if self.model_name in self.timm_model_names:
pretrained_path = model_urls[self.model_name]
default_pretrained_model_path = model_urls[self.model_name]
print_log(
'load model from default path: {}'.format(pretrained_path),
logger)
if pretrained_path.endswith('.npz'):
'load model from default path: {}'.format(
default_pretrained_model_path), logger)
if default_pretrained_model_path.endswith('.npz'):
pretrained_loc = download_cached_file(
pretrained_path, check_hash=False, progress=False)
default_pretrained_model_path,
check_hash=False,
progress=False)
return self.model.load_pretrained(pretrained_loc)
else:
backbone_module = importlib.import_module(
self.model.__module__)
return load_pretrained(
self.model,
default_cfg={'url': pretrained_path},
default_cfg={'url': default_pretrained_model_path},
filter_fn=backbone_module.checkpoint_filter_fn
if hasattr(backbone_module, 'checkpoint_filter_fn')
else None)
elif self.model_name in _MODEL_MAP:
if self.model_name in model_urls.keys():
pretrained_path = model_urls[self.model_name]
default_pretrained_model_path = model_urls[self.model_name]
print_log(
'load model from default path: {}'.format(
pretrained_path), logger)
default_pretrained_model_path), logger)
try_max = 3
try_idx = 0
while try_idx < try_max:
try:
state_dict = torch.hub.load_state_dict_from_url(
url=pretrained_path,
url=default_pretrained_model_path,
map_location='cpu',
)
try_idx += try_max
@ -164,6 +165,7 @@ class PytorchImageModelWrapper(nn.Module):
'Error: Fail to create {} with (pretrained={}...)'.format(
self.model_name, pretrained))
else:
print_log('load model from init weights')
self.model.init_weights()
def forward(self, x):

View File

@ -14,8 +14,6 @@ import torch.nn.functional as F
from torch.nn import Conv2d, Module, ReLU
from torch.nn.modules.utils import _pair
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..registry import BACKBONES

View File

@ -4,8 +4,6 @@ import torch.utils.checkpoint as cp
from mmcv.cnn import constant_init, kaiming_init
from torch.nn.modules.batchnorm import _BatchNorm
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..modelzoo import resnet as model_urls
from ..registry import BACKBONES
from ..utils import FReLU, build_conv_layer, build_norm_layer
@ -443,6 +441,9 @@ class ResNet(nn.Module):
if num_classes > 0:
self.fc = nn.Linear(self.feat_dim, num_classes)
self.default_pretrained_model_path = model_urls.get(
self.__class__.__name__ + str(depth), None)
@property
def norm1(self):
return getattr(self, self.norm1_name)

View File

@ -6,8 +6,6 @@ import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init
from torch.nn.modules.batchnorm import _BatchNorm
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer

View File

@ -3,6 +3,7 @@ import math
import torch.nn as nn
from ..modelzoo import resnext as model_urls
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
from .resnet import Bottleneck as _Bottleneck
@ -172,6 +173,7 @@ class ResNeXt(ResNet):
self.inplanes = 64
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
stride = self.strides[i]
dilation = self.dilations[i]
@ -196,3 +198,7 @@ class ResNeXt(ResNet):
self.res_layers.append(layer_name)
self._freeze_stages()
self.default_pretrained_model_path = model_urls.get(
self.__class__.__name__ + str(self.depth) + '-' +
str(self.groups) + 'x' + str(self.base_width) + 'd', None)

View File

@ -7,8 +7,6 @@ from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_
from torch import nn
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..registry import BACKBONES

View File

@ -16,8 +16,6 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..registry import BACKBONES

View File

@ -15,9 +15,6 @@ import torch.nn as nn
# from utils import trunc_normal_
from timm.models.layers import trunc_normal_
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:

View File

@ -19,8 +19,6 @@ import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.vision_transformer import Mlp, _cfg
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..registry import BACKBONES

View File

@ -33,7 +33,7 @@ class Classification(BaseModel):
with_sobel=False,
head=None,
neck=None,
pretrained=None,
pretrained=True,
mixup_cfg=None):
super(Classification, self).__init__()
self.with_sobel = with_sobel
@ -110,16 +110,29 @@ class Classification(BaseModel):
self.init_weights()
def init_weights(self):
logger = get_root_logger()
if isinstance(self.pretrained, str):
logger = get_root_logger()
load_checkpoint(
self.backbone, self.pretrained, strict=False, logger=logger)
else:
print_log('load model from init weights')
elif self.pretrained:
if self.backbone.__class__.__name__ == 'PytorchImageModelWrapper':
self.backbone.init_weights(pretrained=self.pretrained)
elif hasattr(self.backbone, 'default_pretrained_model_path'
) and self.backbone.default_pretrained_model_path:
print_log(
'load model from default path: {}'.format(
self.backbone.default_pretrained_model_path), logger)
load_checkpoint(
self.backbone,
self.backbone.default_pretrained_model_path,
strict=False,
logger=logger)
else:
print_log('load model from init weights')
self.backbone.init_weights()
else:
print_log('load model from init weights')
self.backbone.init_weights()
for idx in range(self.head_num):
h = getattr(self, 'head_%d' % idx)

View File

@ -12,6 +12,34 @@ resnet = {
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet152.pth',
}
resnext = {
'ResNeXt50-32x4d':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth',
'ResNeXt101-32x4d':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext101-32x4d/epoch_100.pth',
'ResNeXt101-32x8d':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext101-32x8d/epoch_100.pth',
'ResNext152-32x4d':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext152-32x4d/epoch_100.pth',
}
hrnet = {
'HRNetw18':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw18/epoch_100.pth',
'HRNetw30':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw30/epoch_100.pth',
'HRNetw32':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw32/epoch_100.pth',
'HRNetw40':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw40/epoch_100.pth',
'HRNetw44':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw44/epoch_100.pth',
'HRNetw48':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw48/epoch_100.pth',
'HRNetw64':
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw64/epoch_100.pth',
}
mobilenetv2 = {
'MobileNetV2_1.0':
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/mobilenetv2/mobilenet_v2.pth',