694 lines
18 KiB
Python
694 lines
18 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
from typing import List
|
|
from typing import Dict, Callable
|
|
from mmrazor.registry import MODELS
|
|
from mmengine.config import Config
|
|
import os
|
|
from mmengine.utils import get_installed_path
|
|
from mmrazor.registry import MODELS
|
|
import torch
|
|
import torch.nn as nn
|
|
from .models import (AddCatModel, ConcatModel, ConvAttnModel, DwConvModel,
|
|
ExpandLineModel, GroupWiseConvModel, SingleLineModel,
|
|
MultiBindModel, MultiConcatModel, MultiConcatModel2,
|
|
ResBlock, Xmodel, MultipleUseModel, Icep, SelfAttention)
|
|
import json
|
|
# model generator
|
|
from mmdet.testing._utils import demo_mm_inputs
|
|
import string
|
|
import copy
|
|
# helper functions
|
|
|
|
|
|
def get_shape(tensor, only_length=False):
    """Return the shape(s) of a tensor or of a nested container of tensors.

    Args:
        tensor: A ``torch.Tensor`` or a list / tuple / dict (arbitrarily
            nested) whose leaves are tensors.
        only_length: If True, report ``len(t.shape)`` (the number of
            dimensions) for each tensor instead of its full shape.

    Returns:
        The same structure as the input, with each tensor replaced by its
        ``torch.Size`` (or int when ``only_length``). Lists and tuples both
        come back as lists; dicts keep their keys.

    Raises:
        NotImplementedError: If a leaf is not a tensor, list, tuple or dict.
    """
    if isinstance(tensor, torch.Tensor):
        return len(tensor.shape) if only_length else tensor.shape
    if isinstance(tensor, (list, tuple)):
        return [get_shape(x, only_length) for x in tensor]
    if isinstance(tensor, dict):
        return {key: get_shape(tensor[key], only_length) for key in tensor}
    raise NotImplementedError(
        f'unsupported type {type(tensor)} to get shape of tensors.')
|
|
|
|
|
|
# generators
|
|
|
|
|
|
class ModelGenerator(nn.Module):
    """Wrap a zero-argument model factory together with naming helpers.

    ``name`` is a dotted string: its first segment is treated as the scope
    (e.g. ``'default'`` or ``'torch'``) and the remainder as the base name.
    Calling an instance (``gen()``) builds and returns a new model from the
    factory; it does not go through :meth:`forward`.
    """

    def __init__(self, name: str, model_src) -> None:
        super().__init__()
        self.name = name
        self.model_src = model_src
        # Not assigned anywhere in this class; forward() asserts it is set.
        self._model = None

    def __call__(self, *args, **kwargs):
        # Overrides nn.Module.__call__; any arguments are ignored.
        return self.init_model()

    def init_model(self):
        """Invoke the factory and return the freshly built model."""
        factory = self.model_src
        return factory()

    def forward(self, x):
        """Run ``self._model`` on ``x`` followed by the :meth:`input` args."""
        assert self._model is not None
        model = self._model
        return model(x, *self.input())

    def input(self):
        """Extra positional arguments appended after ``x`` in forward()."""
        return []

    def assert_model_is_changed(self, tensors_org, tensors_new):
        """Assert that two (nested) tensor outputs have identical shapes.

        Note: despite the name, the assertion passes when the shapes are
        the same.
        """
        org_shape = get_shape(tensors_org)
        new_shape = get_shape(tensors_new)
        assert org_shape == new_shape, f'{org_shape}!={new_shape}'

    def __repr__(self) -> str:
        return self.name

    @classmethod
    def get_base_name(cls, name: str):
        """Return ``name`` with its first dotted segment (the scope) removed."""
        return name.partition('.')[2]

    @classmethod
    def get_short_name(cls, name: str):
        """Return ``'<scope>.<family>'``, e.g. ``'torch.resnet'`` for
        ``'torch.resnet18'``.

        The family is the first token of the base name after splitting on
        ``'-'``, ``'_'`` and ``'.'``, with trailing digits stripped.
        """
        scope = name.partition('.')[0]
        base = cls.get_base_name(name)
        for sep in ('-', '_'):
            base = base.replace(sep, '.')
        family = base.split('.')[0].rstrip(string.digits)
        return f'{scope}.{family}'

    @property
    def base_name(self):
        return type(self).get_base_name(self.name)

    @property
    def short_name(self):
        return type(self).get_short_name(self.name)

    @property
    def scope(self):
        return self.name.partition('.')[0]
|
|
|
|
|
|
class MMModelGenerator(ModelGenerator):
    """Generator that builds an OpenMMLab model from a config.

    The model is built through the ``MODELS`` registry, then passed through
    ``revert_sync_batchnorm`` (presumably so it can run outside a
    distributed setup).
    """

    def __init__(self, name, cfg) -> None:
        # The factory handed to the base class is the bound method below;
        # it reads ``self.cfg`` only when invoked.
        self.cfg = cfg
        super().__init__(name, self.get_model_src)

    def get_model_src(self):
        """Build a fresh model from ``self.cfg``; used as the factory."""
        built = MODELS.build(self.cfg)
        return revert_sync_batchnorm(built)

    def __repr__(self) -> str:
        return self.name
|
|
|
|
|
|
class MMDetModelGenerator(MMModelGenerator):
    """MMDetection variant: runs the model in ``mode='tensor'`` with demo
    data-sample inputs and compares only output nesting depths/ranks."""

    def forward(self, x):
        assert self._model is not None
        detector = self._model
        detector.eval()
        return detector(x, **self.input(), mode='tensor')

    def input(self):
        """Keyword inputs (everything except ``'inputs'``) produced by the
        model's data preprocessor for one demo 3x224x224 sample."""
        demo = demo_mm_inputs(1, [[3, 224, 224]])
        processed = self._model.data_preprocessor(demo, False)
        # The image tensor is passed separately as ``x`` in forward().
        processed.pop('inputs')
        return processed

    def assert_model_is_changed(self, tensors_org, tensors_new):
        """Assert both outputs have the same structure of tensor ranks."""
        org_ranks = get_shape(tensors_org, True)
        new_ranks = get_shape(tensors_new, True)
        assert org_ranks == new_ranks
|
|
|
|
|
|
# model library
|
|
|
|
|
|
class ModelLibrary:
    """Base collection of model generators, partitioned by name filters.

    Subclasses implement :meth:`get_models`, returning a mapping from model
    name to generator. On construction every model is classified as:

    * excluded: any ``exclude`` key occurs anywhere in its name;
    * included: otherwise, its name starts with any ``include`` key;
    * unincluded: neither of the above.
    """

    default_includes: List = []
    # Per-class cache of ``get_models()``; filled lazily by ``models``.
    _models = None

    def __init__(self, include=default_includes, exclude=[]) -> None:
        self.include_key = include
        self.exclude_key = exclude
        self._include_models, self._uninclude_models, self.exclude_models =\
            self._classify_models(self.models)

    @property
    def models(self):
        """Mapping of model name -> generator, cached on the concrete class.

        The cache is looked up in the class's own ``__dict__`` so that a
        subclass never reuses models cached by one of its parent classes
        (whose ``get_models()`` may return something different).
        """
        klass = self.__class__
        if klass.__dict__.get('_models') is None:
            klass._models = klass.get_models()
        return klass._models

    @classmethod
    def get_models(cls):
        raise NotImplementedError()

    def include_models(self):
        """Generators whose names start with an include key."""
        return self._include_models

    def uninclude_models(self):
        """Generators matching neither an include nor an exclude key."""
        return self._uninclude_models

    def is_include(self, name: str, includes: List[str], start_with=True):
        """Return True if ``name`` matches any key in ``includes``.

        With ``start_with`` a key matches when ``name`` starts with it;
        otherwise when it occurs anywhere inside ``name``.
        """
        if start_with:
            return any(name.startswith(key) for key in includes)
        return any(key in name for key in includes)

    def is_default_includes_cover_all_models(self):
        """Check that every model name matches a ``default_includes`` prefix.

        Prints each uncovered name and returns False if there is any.
        Uses the ``models`` property so subclass overrides (e.g. extra
        models merged in) are taken into account.
        """
        is_covered = True
        for name in self.models:
            if not self.is_include(name, self.__class__.default_includes):
                is_covered = False
                print(name, '\tnot include')
        return is_covered

    def short_names(self):
        """Set of ``short_name`` values of all generators in the library."""
        return {generator.short_name for generator in self.models.values()}

    def _classify_models(self, models: Dict):
        """Split ``models`` into (include, uninclude, exclude) lists.

        Exclusion is checked first, so it wins over inclusion.
        """
        include = []
        uninclude = []
        exclude = []
        for name in models:
            if self.is_include(name, self.exclude_key, start_with=False):
                exclude.append(models[name])
            elif self.is_include(name, self.include_key, start_with=True):
                include.append(models[name])
            else:
                uninclude.append(models[name])
        return include, uninclude, exclude

    def get_short_name_of_model(self, name: str):
        """First token of ``name`` after splitting on '-', '_' and '.'."""
        names = name.replace('-', '.').replace('_', '.').split('.')
        return names[0]
|
|
|
|
|
|
class DefaultModelLibrary(ModelLibrary):
    """The hand-written test models, optionally plus three OpenMMLab models
    (mmcls ResNet-34, mmseg PSPNet, mmdet YOLOv3) loaded from configs."""

    # Class-level cache for get_mm_models(); filled on first access.
    _mm_models = None

    default_includes: List = [
        'SingleLineModel', 'ResBlock', 'AddCatModel', 'ConcatModel',
        'MultiConcatModel', 'MultiConcatModel2', 'GroupWiseConvModel',
        'Xmodel', 'MultipleUseModel', 'Icep', 'ExpandLineModel',
        'MultiBindModel', 'DwConvModel', 'ConvAttnModel', 'SelfAttention',
        # mm models
        'resnet', 'pspnet', 'yolo'
    ]

    def __init__(self,
                 include=default_includes,
                 exclude=[],
                 with_mm_models=False) -> None:
        # Must be set before ModelLibrary.__init__, which reads self.models.
        self.with_mm_models = with_mm_models
        super().__init__(include, exclude)

    @property
    def models(self):
        """Shallow copy of the base models, plus mm models when enabled."""
        merged = dict(super().models)
        if not self.with_mm_models:
            return merged
        merged.update(self.mm_models)
        return merged

    @property
    def mm_models(self):
        cls = self.__class__
        if cls._mm_models is None:
            cls._mm_models = self.get_mm_models()
        return cls._mm_models

    @classmethod
    def get_models(cls):
        """Map each test-model class name to a ``'default.<name>'`` generator."""
        model_classes = (
            SingleLineModel,
            ResBlock,
            AddCatModel,
            ConcatModel,
            MultiConcatModel,
            MultiConcatModel2,
            GroupWiseConvModel,
            Xmodel,
            MultipleUseModel,
            Icep,
            ExpandLineModel,
            MultiBindModel,
            DwConvModel,
            ConvAttnModel,
            SelfAttention,
        )
        return {
            model_cls.__name__:
            ModelGenerator('default.' + model_cls.__name__, model_cls)
            for model_cls in model_classes
        }

    @classmethod
    def get_mm_models(cls):
        """Load the fixed OpenMMLab configs, keyed by generator base name."""
        config_paths = (
            'mmcls::resnet/resnet34_8xb32_in1k.py',
            'mmseg::pspnet/pspnet_r18-d8_4xb4-80k_potsdam-512x512.py',
            'mmdet::yolo/yolov3_d53_8xb8-320-273e_coco.py',
        )
        generators = (
            MMModelLibrary.get_model_from_path(path) for path in config_paths)
        return {generator.base_name: generator for generator in generators}
|
|
|
|
|
|
class TorchModelLibrary(ModelLibrary):
    """Model builder functions exported by ``torchvision.models``."""

    default_includes = [
        'alexnet', 'densenet', 'efficientnet', 'googlenet', 'inception',
        'mnasnet', 'mobilenet', 'regnet', 'resnet', 'resnext', 'shufflenet',
        'squeezenet', 'vgg', 'wide_resnet', "vit", "swin", "convnext"
    ]

    def __init__(self, include=default_includes, exclude=[]) -> None:
        super().__init__(include, exclude)

    @classmethod
    def get_models(cls):
        """Wrap every builder function found in ``torchvision.models``.

        Returns:
            dict: attribute name -> :class:`ModelGenerator` named
            ``'torch.<name>'``.
        """
        from inspect import isfunction

        import torchvision

        # Weight/model lookup helpers exported alongside the builders (most
        # of these only exist in recent torchvision releases); they do not
        # build a model, so they must not be wrapped as one.
        non_builders = {
            'get_weight', 'get_model', 'get_model_builder',
            'get_model_weights', 'list_models'
        }
        models = {}
        for name in dir(torchvision.models):
            module = getattr(torchvision.models, name)
            # Membership / equality, not ``is not``: identity comparison
            # with a string literal is unreliable (SyntaxWarning on 3.8+).
            if isfunction(module) and name not in non_builders:
                models[name] = ModelGenerator('torch.' + name, module)
        return models
|
|
|
|
|
|
class MMModelLibrary(ModelLibrary):
    """Base library that discovers models from an OpenMMLab repo's configs.

    Subclasses set ``repo`` (the installed package name) and
    ``base_config_path`` (appended to the repo's ``.mim/configs/`` dir).
    """

    default_includes = []
    base_config_path = '/'
    repo = 'mmxx'

    def __init__(self, include=default_includes, exclude=[]) -> None:
        super().__init__(include, exclude)

    @classmethod
    def scope_path(cls):
        """Directory walked by :meth:`get_models`."""
        path = cls._scope_path(cls.repo) + cls.base_config_path
        return path

    @classmethod
    def get_models(cls):
        """Build a generator for every distinct ``model`` config found.

        Config files that fail to load are skipped. Configs whose processed
        ``model`` dict serializes to the same JSON as one seen earlier are
        de-duplicated (the first one encountered wins).
        """
        models = {}
        added_models = set()
        root = cls.scope_path()
        for dirpath, _, filenames in os.walk(root):
            for filename in filenames:
                if not filename.endswith('.py'):
                    continue
                cfg_path = dirpath + '/' + filename
                try:
                    config = Config.fromfile(cfg_path)
                except Exception:
                    # Best effort: skip configs that cannot be loaded, but
                    # let KeyboardInterrupt / SystemExit propagate.
                    continue
                if 'model' not in config:
                    continue
                model_name = cls.get_model_name_from_path(cfg_path, root)
                model_cfg = cls._config_process(config['model'])
                cfg_key = json.dumps(model_cfg)
                if cfg_key not in added_models:
                    models[model_name] = cls.generator_type()(
                        cls.repo + '.' + model_name, model_cfg)
                    added_models.add(cfg_key)
        return models

    @classmethod
    def generator_type(cls):
        """Generator class used to wrap each discovered config."""
        return MMModelGenerator

    @classmethod
    def get_model_name_from_path(cls, config_path, scope_path):
        """Derive a model name from a config path relative to ``scope_path``.

        The relative directory components are joined with ``'_'`` and
        prefixed to the file stem, e.g. ``<scope>/faster_rcnn/x_coco.py``
        becomes ``'faster_rcnn_x_coco'``. A config directly inside
        ``scope_path`` is named by its stem alone.
        """
        dirpath = os.path.dirname(config_path) + '/'
        filename = os.path.basename(config_path)
        # Drop the trailing '/' before splitting: otherwise the empty last
        # segment adds a stray '_' and names come out as 'dir__stem', which
        # substring excludes written as 'dir_stem' never match.
        relative_dir = dirpath.replace(scope_path, '').rstrip('/')
        model_type_name = '_'.join(relative_dir.split('/'))
        if model_type_name != '':
            model_type_name += '_'
        return model_type_name + filename.split('.')[0]

    @classmethod
    def get_model_from_path(cls, config_path):
        """Build a generator from a single config.

        Args:
            config_path: A path accepted by ``Config._get_cfg_path``, e.g.
                ``'mmcls::resnet/resnet34_8xb32_in1k.py'``.

        Returns:
            A ``generator_type()`` instance named ``'<scope>.<model name>'``;
            the scope defaults to ``'mmrazor'`` when none is given.
        """
        path, scope = Config._get_cfg_path(config_path, '')
        if scope is None:
            scope = 'mmrazor'
        config = Config.fromfile(path)['model']
        config = cls._config_process(config=config)
        # The path's own scope overrides the repo scope set above.
        config['_scope_'] = scope
        name = cls.get_model_name_from_path(path, cls._scope_path(scope))
        return cls.generator_type()(scope + '.' + name, config)

    @staticmethod
    def _scope_path(scope):
        """Return ``<installed package path>/.mim/configs/`` for a scope."""
        if scope == 'mmseg':
            # The registry scope differs from the installed package name.
            scope = 'mmsegmentation'
        repo_path = get_installed_path(scope)
        path = repo_path + '/.mim/configs/'
        return path

    @classmethod
    def _config_process(cls, config: Dict):
        """Set ``_scope_`` to ``cls.repo`` and strip init/pretrained keys.

        Mutates ``config`` in place and returns it.
        """
        config['_scope_'] = cls.repo
        config = cls._remove_certain_key(config, 'init_cfg')
        config = cls._remove_certain_key(config, 'pretrained')
        config = cls._remove_certain_key(config, 'Pretrained')
        return config

    @classmethod
    def _remove_certain_key(cls, config: Dict, key: str = 'init_cfg'):
        """Recursively delete ``key`` from ``config`` and its nested dicts.

        Values inside lists/tuples are not visited. Returns ``config``.
        """
        if isinstance(config, dict):
            if key in config:
                config.pop(key)
            for keyx in config:
                config[keyx] = cls._remove_certain_key(config[keyx], key)
        return config
|
|
|
|
|
|
class MMClsModelLibrary(MMModelLibrary):
    """Models discovered under MMClassification's ``_base_/models`` configs."""

    # Model-name prefixes included by default (each listed once).
    default_includes = [
        'vgg', 'efficientnet', 'resnet', 'mobilenet', 'resnext',
        'wide-resnet', 'shufflenet', 'hrnet', 'resnest', 'inception',
        'res2net', 'densenet', 'convnext', 'regnet', 'van',
        'swin_transformer', 'convmixer', 't2t', 'twins', 'repmlp', 'tnt',
        'mlp_mixer', 'conformer', 'poolformer', 'vit', 'efficientformer',
        'mobileone', 'edgenext', 'mvit', 'seresnet', 'repvgg', 'seresnext',
        'deit', 'replknet', 'hornet', 'mobilevit', 'davit',
    ]
    base_config_path = '_base_/models/'
    repo = 'mmcls'

    def __init__(
            self,
            include=default_includes,
            exclude=['cutmix', 'cifar', 'gem', 'efficientformer']) -> None:
        # Exclusion is checked before inclusion, so 'efficientformer' is
        # excluded by default even though it is also an include prefix.
        super().__init__(include=include, exclude=exclude)
|
|
|
|
|
|
class MMDetModelLibrary(MMModelLibrary):
    """Models discovered from the whole MMDetection config tree."""

    # Model-name prefixes included by default.
    default_includes = [
        '_base', 'gfl', 'sparse', 'simple', 'pisa', 'lvis', 'carafe',
        'selfsup', 'solo', 'ssd', 'res2net', 'yolof', 'reppoints', 'htc',
        'groie', 'dyhead', 'grid', 'soft', 'swin', 'regnet', 'gcnet', 'ddod',
        'instaboost', 'point', 'vfnet', 'pafpn', 'ghm', 'mask', 'resnest',
        'tood', 'detectors', 'cornernet', 'convnext', 'cascade', 'paa',
        'detr', 'rpn', 'ld', 'lad', 'ms', 'faster', 'centripetalnet', 'gn',
        'dcnv2', 'legacy', 'panoptic', 'strong', 'fpg', 'deformable', 'free',
        'scratch', 'openimages', 'fsaf', 'rtmdet', 'solov2', 'yolact',
        'empirical', 'centernet', 'hrnet', 'guided', 'deepfashion', 'fast',
        'mask2former', 'retinanet', 'autoassign', 'gn+ws', 'dcn', 'yolo',
        'foveabox', 'libra', 'double', 'queryinst', 'resnet', 'nas', 'sabl',
        'fcos', 'scnet', 'maskformer', 'pascal', 'cityscapes', 'timm',
        'seesaw', 'pvt', 'atss', 'efficientnet', 'wider', 'tridentnet',
        'dynamic', 'yolox', 'albu', 'misc', 'crowddet', 'condins',
    ]
    base_config_path = '/'
    repo = 'mmdet'

    def __init__(
        self,
        include=default_includes,
        exclude=[
            'lad',
            'ld',
            'faster_rcnn_faster-rcnn_r50-caffe-c4_ms-1x_coco',
        ]
    ) -> None:
        super().__init__(include=include, exclude=exclude)

    @classmethod
    def _config_process(cls, config: Dict):
        """Apply the base processing, then drop any ``preprocess_cfg``."""
        processed = super()._config_process(config)
        if 'preprocess_cfg' not in processed:
            return processed
        processed.pop('preprocess_cfg')
        return processed

    @classmethod
    def generator_type(cls):
        # NOTE(review): this returns the same class as the base library;
        # MMDetModelGenerator (defined in this file) is not used here —
        # confirm whether that is intended.
        return MMModelGenerator
|
|
|
|
|
|
class MMSegModelLibrary(MMModelLibrary):
    """Models discovered from the MMSegmentation config tree."""

    # Model-name prefixes included by default ('_base_' is also excluded
    # by default, and exclusion wins).
    default_includes: List = [
        '_base_', 'knet', 'sem', 'dnlnet', 'dmnet', 'icnet', 'apcnet', 'swin',
        'isanet', 'fastfcn', 'poolformer', 'mae', 'segformer', 'ccnet',
        'twins', 'emanet', 'upernet', 'beit', 'hrnet', 'bisenetv2', 'vit',
        'setr', 'cgnet', 'ocrnet', 'ann', 'erfnet', 'point', 'bisenetv1',
        'nonlocal', 'unet', 'danet', 'stdc', 'fcn', 'encnet', 'resnest',
        'mobilenet', 'convnext', 'deeplabv3', 'pspnet', 'gcnet', 'fastscnn',
        'segmenter', 'dpt', 'deeplabv3plus', 'psanet',
    ]
    base_config_path = '/'
    repo = 'mmsegmentation'

    def __init__(self, include=default_includes, exclude=['_base_']) -> None:
        super().__init__(include, exclude)

    @classmethod
    def _config_process(cls, config: Dict):
        # Replaces (rather than extends) the base implementation: only the
        # registry scope is set; 'init_cfg' / 'pretrained' keys are kept.
        # NOTE(review): confirm that skipping the key removal is intended.
        config['_scope_'] = 'mmseg'
        return config
|
|
|
|
|
|
class MMPoseModelLibrary(MMModelLibrary):
    """Models discovered from the MMPose config tree."""

    # Model-name prefixes included by default.
    default_includes: List = ['hand', 'face', 'wholebody', 'body', 'animal']
    base_config_path = '/'
    repo = 'mmpose'

    def __init__(self, include=default_includes, exclude=[]) -> None:
        super().__init__(include=include, exclude=exclude)

    @classmethod
    def _config_process(cls, config: Dict):
        # Only sets the registry scope; unlike the base implementation no
        # init/pretrained keys are removed.
        config['_scope_'] = 'mmpose'
        return config
|
|
|
|
|
|
# tools
|
|
|
|
def revert_sync_batchnorm(module):
    """Recursively replace ``SyncBatchNorm`` layers with ``BatchNorm2d``.

    This is the reverse of ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``:
    https://github.com/pytorch/pytorch/blob/c8b3686a3e4ba63dc59e5dcfe5db3430df256833/torch/nn/modules/batchnorm.py#L679

    Each replacement shares (does not copy) the original's affine
    parameters and running statistics, keeps its ``qconfig`` if present,
    and inherits its train/eval mode.

    Args:
        module (nn.Module): Module to convert; may itself be a
            ``SyncBatchNorm``.

    Returns:
        nn.Module: A new ``BatchNorm2d`` when ``module`` is a
        ``SyncBatchNorm``; otherwise ``module`` itself, with its children
        converted in place.
    """
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
        module_output = nn.BatchNorm2d(module.num_features, module.eps,
                                       module.momentum, module.affine,
                                       module.track_running_stats)
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        # A freshly built module starts in training mode; keep the original
        # mode so an eval-mode SyncBatchNorm stays in eval mode.
        module_output.training = module.training
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    return module_output
|