mmrazor/tests/data/tracer_passed_models.py

521 lines
13 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from .model_library import (MMClsModelLibrary, MMDetModelLibrary,
DefaultModelLibrary, TorchModelLibrary,
MMPoseModelLibrary, MMSegModelLibrary)
class PassedModelManager:
def __init__(self) -> None:
pass
def include_models(self, full_test=False):
models = []
for library in self.libraries(full_test):
models.extend(library.include_models())
return models
def uninclude_models(self, full_test=False):
models = []
for library in self.libraries(full_test):
models.extend(library.uninclude_models())
return models
def libraries(self, full=False):
return []
class FxPassedModelManager(PassedModelManager):
_default_library = None
_torch_library = None
_mmcls_library = None
_mmseg_library = None
_mmdet_library = None
_mmpose_library = None
def libraries(self, full=False):
if full:
return [
self.__class__.default_library(),
self.__class__.torch_library(),
self.__class__.mmcls_library(),
self.__class__.mmseg_library(),
self.__class__.mmdet_library(),
self.__class__.mmpose_library(),
]
else:
return [self.__class__.default_library()]
@classmethod
def default_library(cls):
if cls._default_library is None:
cls._default_library = DefaultModelLibrary(include=[
'SingleLineModel',
'ResBlock',
'AddCatModel',
'ConcatModel',
'MultiConcatModel',
'MultiConcatModel2',
'GroupWiseConvModel',
'Xmodel',
'MultipleUseModel',
'Icep',
'ExpandLineModel',
'MultiBindModel',
'DwConvModel',
'ConvAttnModel',
# mm models
'resnet',
'pspnet',
'yolo'
],with_mm_models=True)
return cls._default_library
@classmethod
def torch_library(cls):
"""
googlenet: return a tuple when training, so it should
trace in eval mode
"""
torch_includes = [
'resnext',
'efficientnet',
'inception',
'wide',
'resnet',
'regnet',
'shufflenet',
'mnasnet',
'vit',
'convnext',
'googlenet',
'densenet',
'swin',
'vgg',
'mobilenet',
'squeezenet',
'alexnet',
]
if cls._torch_library is None:
cls._torch_library = TorchModelLibrary(include=torch_includes)
return cls._torch_library
@classmethod
def mmcls_library(cls):
"""
shufflenet consists of chunk operations.
resnest: resnest has two problems. First it uses *x.shape() which is
not tracerable using fx tracer. Second, it uses channel folding.
res2net: res2net consists of split operations.
convnext: consist of layernorm.
"""
mmcls_include = [
'tnt',
'resnet',
'resnetv1c',
'mobileone',
'mlp',
'densenet',
'hrnet',
'seresnet',
'van',
'repmlp',
'repvgg',
'vgg',
'vgg11bn',
'edgenext',
'vgg19bn',
'wide',
'res2net',
'vgg13bn',
'resnetv1d',
'mobilenet',
'convmixer',
'resnest',
'inception',
'resnext',
'twins',
'vgg16bn',
'shufflenet',
'conformer',
'regnet',
'seresnext',
'vit',
'poolformer',
't2t',
'efficientnet',
## error
# 'deit',
# 'swin',
# 'convnext',
# 'mvit'
]
if cls._mmcls_library is None:
cls._mmcls_library = MMClsModelLibrary(include=mmcls_include)
return cls._mmcls_library
@classmethod
def mmdet_library(cls):
mmdet_include = [
'pafpn',
'gn+ws',
'paa',
'fcos',
'autoassign',
'centripetalnet',
'retinanet',
'cornernet',
'gn',
'instaboost',
'rpn',
'fpg',
'crowddet',
'resnest',
'pvt',
'solo',
'grid',
'free',
'point',
'yolo',
'double',
'dynamic',
'maskformer',
'scratch',
'nas',
'yolof',
'faster',
'atss',
'yolox',
'fsaf',
'ghm',
'centernet',
'seesaw',
'regnet',
'cityscapes',
'lvis',
'sabl',
'gfl',
'tridentnet',
'selfsup',
'deepfashion',
'efficientnet',
'foveabox',
'mask',
## errors
# 'timm',
# 'swin',
# 'dyhead',
# 'hrnet',
# 'deformable',
# 'ssd',
# 'empirical',
# 'detectors',
# 'reppoints',
# 'scnet',
# 'legacy',
# 'htc',
# 'dcnv',
# 'carafe',
# 'yolact',
# 'panoptic',
# 'misc',
# 'rtmdet',
# 'pascal',
# 'ddod',
# 'mask2former',
# 'tood',
# 'queryinst',
# 'simple',
# 'pisa',
# 'fast',
# 'cascade',
# 'wider',
# 'openimages',
# '',
# 'strong',
# 'res2net',
# 'libra',
# 'vfnet',
# 'soft',
# 'sparse',
# 'gcnet',
# 'convnext',
# 'ms',
# 'dcn',
# 'guided',
# 'groie',
# 'solov',
# 'detr',
]
if cls._mmdet_library is None:
cls._mmdet_library = MMDetModelLibrary(mmdet_include)
return cls._mmdet_library
@classmethod
def mmseg_library(cls):
# a common error: unet related models
include = [
'bisenetv',
'erfnet',
'dmnet',
'twins',
'segformer',
'isanet',
'vit',
'resnest',
'setr',
'cgnet',
'stdc',
'dpt',
'pspnet',
'upernet',
'apcnet',
'gcnet',
'ann',
'ocrnet',
'ccnet',
'deeplabv',
'dnlnet',
'point',
'fastscnn',
'psanet',
'segmenter',
'danet',
'emanet',
'icnet',
'unet',
'fcn',
'swin',
'nonlocal',
'deeplabv3plus',
'sem',
## errors
# 'mobilenet',
# 'mae',
# 'knet',
# 'poolformer',
# 'beit',
# 'encnet',
# 'hrnet',
# 'convnext',
# 'fastfcn'
]
if cls._mmseg_library is None:
cls._mmseg_library = MMSegModelLibrary(include=include)
return cls._mmseg_library
@classmethod
def mmpose_library(cls):
mmpose_include = [
'hand',
'face',
'wholebody',
'body',
'animal',
]
if cls._mmpose_library is None:
cls._mmpose_library = MMPoseModelLibrary(include=mmpose_include)
return cls._mmpose_library
# for backward tracer
class BackwardPassedModelManager(PassedModelManager):
_default_library = None
_torch_library = None
_mmcls_library = None
_mmseg_library = None
_mmdet_library = None
_mmpose_library = None
def libraries(self, full=False):
if full:
return [
self.__class__.default_library(),
self.__class__.torch_library(),
self.__class__.mmcls_library(),
self.__class__.mmseg_library(),
self.__class__.mmdet_library(),
self.__class__.mmpose_library(),
]
else:
return [self.__class__.default_library()]
@classmethod
def default_library(cls):
if cls._default_library is None:
cls._default_library = DefaultModelLibrary(include=[
'SingleLineModel',
'ResBlock',
'AddCatModel',
'ConcatModel',
'MultiConcatModel',
'MultiConcatModel2',
'GroupWiseConvModel',
'Xmodel',
# 'MultipleUseModel', # bug
'Icep',
'ExpandLineModel',
'MultiBindModel',
'DwConvModel',
'ConvAttnModel',
])
return cls._default_library
@classmethod
def torch_library(cls):
"""
googlenet return a tuple when training, so it
should trace in eval mode
"""
torch_includes = [
'alexnet',
'densenet',
'efficientnet',
'googlenet',
'inception',
'mnasnet',
'mobilenet',
'regnet',
'resnet',
'resnext',
# 'shufflenet', # bug
'squeezenet',
'vgg',
'wide_resnet',
# "vit",
# "swin",
# "convnext"
]
if cls._torch_library is None:
cls._torch_library = TorchModelLibrary(include=torch_includes)
return cls._torch_library
@classmethod
def mmcls_library(cls):
"""
shufflenet consists of chunk operations.
resnest: resnest has two problems. First it uses *x.shape() which is
not tracerable using fx tracer. Second, it uses channel folding.
res2net: res2net consists of split operations.
convnext: consist of layernorm.
"""
mmcls_model_include = [
'vgg',
'efficientnet',
'resnet',
'mobilenet',
'resnext',
'wide-resnet',
# 'shufflenet', # bug
'hrnet',
# 'resnest', # bug
'inception',
# 'res2net', # bug
'densenet',
# 'convnext', # bug
'regnet',
# 'van', # bug
# 'swin_transformer', # bug
# 'convmixer', # bug
# 't2t', # bug
# 'twins', # bug
# 'repmlp', # bug
# 'tnt', # bug
# 't2t', # bug
# 'mlp_mixer', # bug
# 'conformer', # bug
# 'poolformer', # bug
# 'vit', # bug
# 'efficientformer',
# 'mobileone',
# 'edgenext'
]
mmcls_exclude = ['cutmix', 'cifar', 'gem']
if cls._mmcls_library is None:
cls._mmcls_library = MMClsModelLibrary(
include=mmcls_model_include, exclude=mmcls_exclude)
return cls._mmcls_library
@classmethod
def mmdet_library(cls):
mmdet_include = [
# 'rpn', #
# 'faster-rcnn',
# 'cascade-rcnn',
# 'fast-rcnn', # mmdet has bug.
# 'retinanet',
# 'mask-rcnn',
# 'ssd300'
]
if cls._mmdet_library is None:
cls._mmdet_library = MMDetModelLibrary(mmdet_include)
return cls._mmdet_library
@classmethod
def mmseg_library(cls):
include = [
# 'cgnet',
# 'gcnet',
# 'setr',
# 'deeplabv3',
# 'twins',
# 'fastfcn',
# 'fpn',
# 'upernet',
# 'dnl',
# 'icnet',
# 'segmenter',
# 'encnet',
# 'erfnet',
# 'segformer',
# 'apcnet',
# 'fast',
# 'ocrnet',
# 'lraspp',
# 'dpt',
# 'fcn',
# 'psanet',
# 'bisenetv2',
# 'pointrend',
# 'ccnet',
'pspnet',
# 'dmnet',
# 'stdc',
# 'ann',
# 'nonlocal',
# 'isanet',
# 'danet',
# 'emanet',
# 'deeplabv3plus',
# 'bisenetv1',
]
if cls._mmseg_library is None:
cls._mmseg_library = MMSegModelLibrary(include=include)
return cls._mmseg_library
@classmethod
def mmpose_library(cls):
mmpose_include = [
'hand',
'face',
'wholebody',
'body',
'animal',
]
if cls._mmpose_library is None:
cls._mmpose_library = MMPoseModelLibrary(include=mmpose_include)
return cls._mmpose_library
fx_passed_library = FxPassedModelManager()
backward_passed_library = BackwardPassedModelManager()