mmrazor/tests/data/model_library.py

694 lines
18 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List
from typing import Dict, Callable
from mmrazor.registry import MODELS
from mmengine.config import Config
import os
from mmengine.utils import get_installed_path
from mmrazor.registry import MODELS
import torch
import torch.nn as nn
from .models import (AddCatModel, ConcatModel, ConvAttnModel, DwConvModel,
ExpandLineModel, GroupWiseConvModel, SingleLineModel,
MultiBindModel, MultiConcatModel, MultiConcatModel2,
ResBlock, Xmodel, MultipleUseModel, Icep, SelfAttention)
import json
# model generator
from mmdet.testing._utils import demo_mm_inputs
import string
import copy
# helper functions
def get_shape(tensor, only_length=False):
    """Recursively collect the shapes of tensors in a nested structure.

    Args:
        tensor: A ``torch.Tensor``, or a (possibly nested) list / tuple /
            dict whose leaves are tensors.
        only_length (bool): If True, record the number of dimensions of
            each tensor instead of its full shape.

    Returns:
        For a tensor, its ``shape`` (or ``len(shape)`` when ``only_length``
        is set). For a list or tuple, a list of the results of its items.
        For a dict, a dict with the same keys mapping to the results.

    Raises:
        NotImplementedError: If ``tensor`` (or a nested leaf) is of any
            other type.
    """
    if isinstance(tensor, torch.Tensor):
        return len(tensor.shape) if only_length else tensor.shape
    if isinstance(tensor, (list, tuple)):
        # Tuples are reported as lists, same as lists.
        return [get_shape(item, only_length) for item in tensor]
    if isinstance(tensor, dict):
        return {
            key: get_shape(value, only_length)
            for key, value in tensor.items()
        }
    raise NotImplementedError(
        f'unsupported type {type(tensor)} to get shape of tensors.')
# generators
class ModelGenerator(nn.Module):
    """Named factory that builds a model on demand.

    Note that calling an instance does not run ``forward``: ``__call__`` is
    overridden to return a freshly built model from ``model_src``.

    Args:
        name (str): Dotted name of the form ``'<scope>.<base name>'``.
        model_src: Zero-argument callable that returns a model.
    """

    def __init__(self, name: str, model_src) -> None:
        super().__init__()
        self.name = name
        self.model_src = model_src
        self._model = None

    def __call__(self, *args, **kwargs):
        # Arguments are accepted but ignored; a new model is always built.
        return self.init_model()

    def init_model(self):
        """Build a model by calling ``model_src`` without arguments."""
        return self.model_src()

    def forward(self, x):
        """Run the stored model on ``x`` plus the extras from ``input()``."""
        assert self._model is not None
        extra_args = self.input()
        return self._model(x, *extra_args)

    def input(self):
        """Return extra positional inputs for ``forward`` (none here)."""
        return []

    def assert_model_is_changed(self, tensors_org, tensors_new):
        """Assert that both tensor structures have identical shapes."""
        shape_org = get_shape(tensors_org)
        shape_new = get_shape(tensors_new)
        assert shape_org == shape_new, f'{shape_org}!={shape_new}'

    def __repr__(self) -> str:
        return self.name

    @classmethod
    def get_base_name(cls, name: str):
        """Return everything after the first ``'.'`` ('' if there is none)."""
        _scope, _dot, base = name.partition('.')
        return base

    @classmethod
    def get_short_name(cls, name: str):
        """Return ``'<scope>.<family>'`` for a full generator name.

        The family is the first ``'.'``/``'-'``/``'_'``-separated piece of
        the base name with trailing digits stripped.
        """
        scope = name.partition('.')[0]
        normalized = cls.get_base_name(name).replace('-', '.').replace(
            '_', '.')
        family = normalized.split('.')[0].rstrip(string.digits)
        return f'{scope}.{family}'

    @property
    def base_name(self):
        return type(self).get_base_name(self.name)

    @property
    def short_name(self):
        return type(self).get_short_name(self.name)

    @property
    def scope(self):
        return self.name.partition('.')[0]
class MMModelGenerator(ModelGenerator):
    """Generator whose model is built from a config via ``MODELS``.

    Args:
        name: Dotted generator name, forwarded to ``ModelGenerator``.
        cfg: Model config passed to ``MODELS.build``.
    """

    def __init__(self, name, cfg) -> None:
        # The config is stored first, then the bound builder method is
        # handed to the parent as the model source.
        self.cfg = cfg
        super().__init__(name, self.get_model_src)

    def get_model_src(self):
        """Build the model and convert its SyncBatchNorm layers."""
        built = MODELS.build(self.cfg)
        return revert_sync_batchnorm(built)

    def __repr__(self) -> str:
        return self.name
class MMDetModelGenerator(MMModelGenerator):
    """Generator variant that feeds demo inputs to the stored model."""

    def forward(self, x):
        """Run the stored model in eval mode with ``mode='tensor'``."""
        assert self._model is not None
        self._model.eval()
        extra_kwargs = self.input()
        return self._model(x, **extra_kwargs, mode='tensor')

    def input(self):
        """Return preprocessed demo data with the ``'inputs'`` entry removed."""
        batch = demo_mm_inputs(1, [[3, 224, 224]])
        processed = self._model.data_preprocessor(batch, False)
        processed.pop('inputs')
        return processed

    def assert_model_is_changed(self, tensors_org, tensors_new):
        """Assert that both structures agree on tensor dimension counts."""
        dims_org = get_shape(tensors_org, True)
        dims_new = get_shape(tensors_new, True)
        assert dims_org == dims_new
# model library
class ModelLibrary:
    """Base class for a named collection of model generators.

    Subclasses implement ``get_models`` to return a ``{name: generator}``
    dict. On construction the models are split into three lists:
    excluded (name contains an ``exclude`` key), included (name starts with
    an ``include`` key) and un-included (everything else).

    Args:
        include: Name prefixes of the models to include.
        exclude: Substrings of the model names to exclude. Exclusion takes
            precedence over inclusion.
    """

    default_includes: List = []
    # Per-class cache of the dict returned by ``get_models``.
    _models = None

    # The default lists are only read, never mutated, so sharing them
    # between instances is harmless.
    def __init__(self, include=default_includes, exclude=[]) -> None:
        self.include_key = include
        self.exclude_key = exclude
        (self._include_models, self._uninclude_models,
         self.exclude_models) = self._classify_models(self.models)

    @property
    def models(self):
        """Dict of all models of this library, built once per class."""
        cls = type(self)
        # Only look in this class's own namespace. A plain attribute lookup
        # would fall back to a parent class's populated cache, so a subclass
        # with its own ``get_models`` would wrongly reuse the parent's dict.
        if cls.__dict__.get('_models') is None:
            cls._models = cls.get_models()
        return cls._models

    @classmethod
    def get_models(cls):
        """Return a ``{name: generator}`` dict; implemented by subclasses."""
        raise NotImplementedError()

    def include_models(self):
        """Return the generators whose names matched an include key."""
        return self._include_models

    def uninclude_models(self):
        """Return the generators that were neither included nor excluded."""
        return self._uninclude_models

    def is_include(self, name: str, includes: List[str], start_with=True):
        """Check whether ``name`` matches any key in ``includes``.

        Args:
            name (str): Model name to test.
            includes (List[str]): Keys to match against.
            start_with (bool): If True a key must be a prefix of ``name``;
                otherwise it may occur anywhere in ``name``.

        Returns:
            bool: True if at least one key matches.
        """
        if start_with:
            return any(name.startswith(key) for key in includes)
        return any(key in name for key in includes)

    def is_default_includes_cover_all_models(self):
        """Report whether every cached model matches ``default_includes``.

        Each uncovered model name is printed.
        """
        default_keys = type(self).default_includes
        is_covered = True
        for name in self._models:
            if not self.is_include(name, default_keys):
                is_covered = False
                print(name, '\tnot include')
        return is_covered

    def short_names(self):
        """Return the set of ``short_name`` values of all generators."""
        return {generator.short_name for generator in self.models.values()}

    def _classify_models(self, models: Dict):
        """Split ``models`` into (included, un-included, excluded) lists."""
        include = []
        uninclude = []
        exclude = []
        for name, generator in models.items():
            # Exclusion is a substring match and wins over inclusion.
            if self.is_include(name, self.exclude_key, start_with=False):
                exclude.append(generator)
            elif self.is_include(name, self.include_key, start_with=True):
                include.append(generator)
            else:
                uninclude.append(generator)
        return include, uninclude, exclude

    def get_short_name_of_model(self, name: str):
        """Return the part of ``name`` before the first '.', '-' or '_'."""
        return name.replace('-', '.').replace('_', '.').split('.')[0]
class DefaultModelLibrary(ModelLibrary):
    """Library of the local toy models, optionally plus a few mm models.

    Args:
        include: Name prefixes of the models to include.
        exclude: Substrings of the model names to exclude.
        with_mm_models (bool): If True, the models from ``get_mm_models``
            are merged into ``models``.
    """

    # Per-class cache of the dict returned by ``get_mm_models``.
    _mm_models = None

    default_includes: List = [
        'SingleLineModel', 'ResBlock', 'AddCatModel', 'ConcatModel',
        'MultiConcatModel', 'MultiConcatModel2', 'GroupWiseConvModel',
        'Xmodel', 'MultipleUseModel', 'Icep', 'ExpandLineModel',
        'MultiBindModel', 'DwConvModel', 'ConvAttnModel', 'SelfAttention',
        # mm models
        'resnet', 'pspnet', 'yolo'
    ]

    def __init__(self,
                 include=default_includes,
                 exclude=[],
                 with_mm_models=False) -> None:
        # Must be set before the parent initialiser, which reads ``models``.
        self.with_mm_models = with_mm_models
        super().__init__(include, exclude)

    @property
    def models(self):
        """Shallow copy of the base models, plus mm models if enabled."""
        merged = copy.copy(super().models)
        if self.with_mm_models:
            merged.update(self.mm_models)
        return merged

    @property
    def mm_models(self):
        """Dict of mm models, built once and cached on the class."""
        klass = self.__class__
        if klass._mm_models is None:
            klass._mm_models = self.get_mm_models()
        return klass._mm_models

    @classmethod
    def get_models(cls):
        """Wrap every local model class in a ``'default.'`` generator."""
        model_classes = (
            SingleLineModel,
            ResBlock,
            AddCatModel,
            ConcatModel,
            MultiConcatModel,
            MultiConcatModel2,
            GroupWiseConvModel,
            Xmodel,
            MultipleUseModel,
            Icep,
            ExpandLineModel,
            MultiBindModel,
            DwConvModel,
            ConvAttnModel,
            SelfAttention,
        )
        return {
            model_cls.__name__: ModelGenerator('default.' + model_cls.__name__,
                                               model_cls)
            for model_cls in model_classes
        }

    @classmethod
    def get_mm_models(cls):
        """Build generators for a fixed set of mm config files."""
        config_paths = (
            'mmcls::resnet/resnet34_8xb32_in1k.py',
            'mmseg::pspnet/pspnet_r18-d8_4xb4-80k_potsdam-512x512.py',
            'mmdet::yolo/yolov3_d53_8xb8-320-273e_coco.py',
        )
        generators = {}
        for config_path in config_paths:
            generator = MMModelLibrary.get_model_from_path(config_path)
            # Keyed by the generator's name without its scope prefix.
            generators[generator.base_name] = generator
        return generators
class TorchModelLibrary(ModelLibrary):
    """Library of the model builder functions of ``torchvision.models``.

    Args:
        include: Name prefixes of the models to include.
        exclude: Substrings of the model names to exclude.
    """

    default_includes = [
        'alexnet', 'densenet', 'efficientnet', 'googlenet', 'inception',
        'mnasnet', 'mobilenet', 'regnet', 'resnet', 'resnext', 'shufflenet',
        'squeezenet', 'vgg', 'wide_resnet', "vit", "swin", "convnext"
    ]

    def __init__(self, include=default_includes, exclude=[]) -> None:
        super().__init__(include, exclude)

    @classmethod
    def get_models(cls):
        """Wrap every plain function of ``torchvision.models`` in a generator.

        Returns:
            dict: ``{function name: ModelGenerator('torch.<name>', function)}``
            for every function attribute except ``get_weight``.
        """
        from inspect import isfunction

        import torchvision

        models = {}
        for name in dir(torchvision.models):
            attr = getattr(torchvision.models, name)
            # Compare strings by value: identity (`is not`) against a
            # literal is implementation-dependent.
            if isfunction(attr) and name != 'get_weight':
                models[name] = ModelGenerator('torch.' + name, attr)
        return models
class MMModelLibrary(ModelLibrary):
    """Base library that collects models from the configs of an mm repo.

    Subclasses set ``repo`` (the package name passed to
    ``get_installed_path``) and ``base_config_path`` (appended to the repo's
    ``.mim/configs/`` directory to form the directory that is scanned).

    Args:
        include: Name prefixes of the models to include.
        exclude: Substrings of the model names to exclude.
    """

    default_includes = []
    base_config_path = '/'
    repo = 'mmxx'

    def __init__(self, include=default_includes, exclude=[]) -> None:
        super().__init__(include, exclude)

    @classmethod
    def scope_path(cls):
        """Return the config directory of ``repo`` plus ``base_config_path``."""
        return cls._scope_path(cls.repo) + cls.base_config_path

    @classmethod
    def get_models(cls):
        """Scan ``scope_path()`` and build one generator per distinct model.

        Every ``.py`` file below the scope path is loaded as a config.
        Files that fail to load or have no ``'model'`` entry are skipped,
        and so are configs whose processed model config (compared by its
        JSON dump) has already been added.

        Returns:
            dict: ``{model name: generator}``.
        """
        models = {}
        added_models = set()
        for dirpath, dirnames, filenames in os.walk(cls.scope_path()):
            for filename in filenames:
                if not filename.endswith('.py'):
                    continue
                cfg_path = dirpath + '/' + filename
                try:
                    config = Config.fromfile(cfg_path)
                except Exception:
                    # Best-effort scan: unreadable configs are skipped, but
                    # KeyboardInterrupt / SystemExit are no longer swallowed.
                    continue
                if 'model' not in config:
                    continue
                # get model_name
                model_name = cls.get_model_name_from_path(
                    cfg_path, cls.scope_path())
                model_cfg = cls._config_process(config['model'])
                # The JSON dump serves as the identity of a model config.
                cfg_key = json.dumps(model_cfg)
                if cfg_key not in added_models:
                    models[model_name] = cls.generator_type()(
                        cls.repo + '.' + model_name, model_cfg)
                    added_models.add(cfg_key)
        return models

    @classmethod
    def generator_type(cls):
        """Return the generator class used to wrap model configs."""
        return MMModelGenerator

    @classmethod
    def get_model_name_from_path(cls, config_path, scope_path):
        """Derive a model name from a config path.

        The directory of ``config_path`` (with a trailing '/') has
        ``scope_path`` removed and its '/' replaced by '_'. If the result is
        non-empty another '_' is appended, and finally the file name up to
        its first '.' is added. Because of the trailing '/', a non-empty
        directory prefix ends with two underscores.

        Args:
            config_path (str): Path of the config file.
            scope_path (str): Path prefix to strip from the directory.

        Returns:
            str: The derived model name.
        """
        dirpath = os.path.dirname(config_path) + '/'
        filename = os.path.basename(config_path)
        model_type_name = '_'.join(dirpath.replace(scope_path, '').split('/'))
        if model_type_name != '':
            model_type_name += '_'
        return model_type_name + os.path.basename(filename).split('.')[0]

    @classmethod
    def get_model_from_path(cls, config_path):
        """Build a single generator from a config path.

        Args:
            config_path (str): Config path, resolved with
                ``Config._get_cfg_path``. When no scope is resolved,
                ``'mmrazor'`` is used.

        Returns:
            A generator named ``'<scope>.<model name>'``.
        """
        path, scope = Config._get_cfg_path(config_path, '')
        if scope is None:
            scope = 'mmrazor'
        config = Config.fromfile(path)['model']
        config = cls._config_process(config=config)
        # Overrides the `_scope_` written by `_config_process`.
        config['_scope_'] = scope
        name = cls.get_model_name_from_path(path, cls._scope_path(scope))
        return cls.generator_type()(scope + '.' + name, config)

    @staticmethod
    def _scope_path(scope):
        """Return ``<installed path of scope>/.mim/configs/``.

        The scope ``'mmseg'`` is looked up as ``'mmsegmentation'``.
        """
        if scope == 'mmseg':
            scope = 'mmsegmentation'
        repo_path = get_installed_path(scope)
        return repo_path + '/.mim/configs/'

    @classmethod
    def _config_process(cls, config: Dict):
        """Set ``_scope_`` to ``repo`` and strip init/pretrained keys.

        The config is modified in place and returned.
        """
        config['_scope_'] = cls.repo
        for key in ('init_cfg', 'pretrained', 'Pretrained'):
            config = cls._remove_certain_key(config, key)
        return config

    @classmethod
    def _remove_certain_key(cls, config: Dict, key: str = 'init_cfg'):
        """Recursively remove ``key`` from ``config`` and nested dicts.

        Only dict values are descended into. Non-dict inputs are returned
        unchanged.
        """
        if isinstance(config, dict):
            if key in config:
                config.pop(key)
            # Re-assigning existing keys does not change the dict's size, so
            # it is safe while iterating.
            for keyx in config:
                config[keyx] = cls._remove_certain_key(config[keyx], key)
        return config
class MMClsModelLibrary(MMModelLibrary):
    """Model library scanning ``_base_/models/`` of the ``mmcls`` configs.

    Args:
        include: Name prefixes of the models to include.
        exclude: Substrings of the model names to exclude.
    """

    default_includes = [
        'vgg', 'efficientnet', 'resnet', 'mobilenet', 'resnext',
        'wide-resnet', 'shufflenet', 'hrnet', 'resnest', 'inception',
        'res2net', 'densenet', 'convnext', 'regnet', 'van',
        'swin_transformer', 'convmixer', 't2t', 'twins', 'repmlp',
        'tnt', 't2t', 'mlp_mixer', 'conformer', 'poolformer',
        'vit', 'efficientformer', 'mobileone', 'edgenext', 'mvit',
        'seresnet', 'repvgg', 'seresnext', 'deit', 'replknet',
        'hornet', 'mobilevit', 'davit',
    ]
    base_config_path = '_base_/models/'
    repo = 'mmcls'

    def __init__(
        self,
        include=default_includes,
        exclude=['cutmix', 'cifar', 'gem', 'efficientformer']
    ) -> None:
        super().__init__(include=include, exclude=exclude)
class MMDetModelLibrary(MMModelLibrary):
    """Model library scanning the configs of ``mmdet``.

    Args:
        include: Name prefixes of the models to include.
        exclude: Substrings of the model names to exclude.
    """

    default_includes = [
        '_base', 'gfl', 'sparse', 'simple', 'pisa',
        'lvis', 'carafe', 'selfsup', 'solo', 'ssd',
        'res2net', 'yolof', 'reppoints', 'htc', 'groie',
        'dyhead', 'grid', 'soft', 'swin', 'regnet',
        'gcnet', 'ddod', 'instaboost', 'point', 'vfnet',
        'pafpn', 'ghm', 'mask', 'resnest', 'tood',
        'detectors', 'cornernet', 'convnext', 'cascade', 'paa',
        'detr', 'rpn', 'ld', 'lad', 'ms',
        'faster', 'centripetalnet', 'gn', 'dcnv2', 'legacy',
        'panoptic', 'strong', 'fpg', 'deformable', 'free',
        'scratch', 'openimages', 'fsaf', 'rtmdet', 'solov2',
        'yolact', 'empirical', 'centernet', 'hrnet', 'guided',
        'deepfashion', 'fast', 'mask2former', 'retinanet', 'autoassign',
        'gn+ws', 'dcn', 'yolo', 'foveabox', 'libra',
        'double', 'queryinst', 'resnet', 'nas', 'sabl',
        'fcos', 'scnet', 'maskformer', 'pascal', 'cityscapes',
        'timm', 'seesaw', 'pvt', 'atss', 'efficientnet',
        'wider', 'tridentnet', 'dynamic', 'yolox', 'albu',
        'misc', 'crowddet', 'condins',
    ]
    base_config_path = '/'
    repo = 'mmdet'

    def __init__(
        self,
        include=default_includes,
        exclude=[
            'lad',
            'ld',
            'faster_rcnn_faster-rcnn_r50-caffe-c4_ms-1x_coco',
        ]
    ) -> None:
        super().__init__(include=include, exclude=exclude)

    @classmethod
    def _config_process(cls, config: Dict):
        """Apply the base processing, then drop ``preprocess_cfg`` if set."""
        processed = super()._config_process(config)
        if 'preprocess_cfg' in processed:
            processed.pop('preprocess_cfg')
        return processed

    @classmethod
    def generator_type(cls):
        # NOTE(review): this returns the generic generator although an
        # MMDetModelGenerator is defined in this file - confirm the intent.
        return MMModelGenerator
class MMSegModelLibrary(MMModelLibrary):
    """Model library scanning the configs of ``mmsegmentation``.

    Args:
        include: Name prefixes of the models to include.
        exclude: Substrings of the model names to exclude.
    """

    default_includes: List = [
        '_base_', 'knet', 'sem', 'dnlnet', 'dmnet',
        'icnet', 'apcnet', 'swin', 'isanet', 'fastfcn',
        'poolformer', 'mae', 'segformer', 'ccnet', 'twins',
        'emanet', 'upernet', 'beit', 'hrnet', 'bisenetv2',
        'vit', 'setr', 'cgnet', 'ocrnet', 'ann',
        'erfnet', 'point', 'bisenetv1', 'nonlocal', 'unet',
        'danet', 'stdc', 'fcn', 'encnet', 'resnest',
        'mobilenet', 'convnext', 'deeplabv3', 'pspnet', 'gcnet',
        'fastscnn', 'segmenter', 'dpt', 'deeplabv3plus', 'psanet',
    ]
    base_config_path = '/'
    repo = 'mmsegmentation'

    def __init__(self, include=default_includes, exclude=['_base_']) -> None:
        super().__init__(include, exclude)

    @classmethod
    def _config_process(cls, config: Dict):
        """Only tag the config with the ``'mmseg'`` scope, in place."""
        config['_scope_'] = 'mmseg'
        return config
class MMPoseModelLibrary(MMModelLibrary):
    """Model library scanning the configs of ``mmpose``.

    Args:
        include: Name prefixes of the models to include.
        exclude: Substrings of the model names to exclude.
    """

    default_includes: List = ['hand', 'face', 'wholebody', 'body', 'animal']
    base_config_path = '/'
    repo = 'mmpose'

    def __init__(self, include=default_includes, exclude=[]) -> None:
        super().__init__(include, exclude=exclude)

    @classmethod
    def _config_process(cls, config: Dict):
        """Only tag the config with the ``'mmpose'`` scope, in place."""
        config['_scope_'] = 'mmpose'
        return config
# tools
def revert_sync_batchnorm(module):
    """Recursively replace ``SyncBatchNorm`` layers with ``BatchNorm2d``.

    A ``SyncBatchNorm`` module is replaced by a new ``nn.BatchNorm2d`` built
    with the same ``num_features``, ``eps``, ``momentum``, ``affine`` and
    ``track_running_stats``. The new layer shares the weight / bias
    parameters (when affine) and the running statistics of the original,
    and a ``qconfig`` attribute is carried over if present.

    Any other module is kept and its children are converted in place.

    Args:
        module (nn.Module): The module (tree) to convert.

    Returns:
        nn.Module: ``module`` itself with converted children, or the new
        ``BatchNorm2d`` if ``module`` was a ``SyncBatchNorm``.
    """
    # this is very similar to the function that it is trying to revert:
    # https://github.com/pytorch/pytorch/blob/c8b3686a3e4ba63dc59e5dcfe5db3430df256833/torch/nn/modules/batchnorm.py#L679
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
        module_output = nn.BatchNorm2d(module.num_features, module.eps,
                                       module.momentum, module.affine,
                                       module.track_running_stats)
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    # Re-registering under the same name swaps converted children in.
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    return module_output