Rename the package name to `mmpretrain`.
parent
8352951f3d
commit
0979e78573
|
@ -17,10 +17,10 @@ from modelindex.load_model_index import load
|
|||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from mmcls.apis import init_model
|
||||
from mmcls.datasets import CIFAR10, CIFAR100, ImageNet
|
||||
from mmcls.utils import register_all_modules
|
||||
from mmcls.visualization import ClsVisualizer
|
||||
from mmpretrain.apis import init_model
|
||||
from mmpretrain.datasets import CIFAR10, CIFAR100, ImageNet
|
||||
from mmpretrain.utils import register_all_modules
|
||||
from mmpretrain.visualization import ClsVisualizer
|
||||
|
||||
console = Console()
|
||||
MMCLS_ROOT = Path(__file__).absolute().parents[2]
|
||||
|
|
|
@ -18,9 +18,9 @@ from modelindex.load_model_index import load
|
|||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from mmcls.datasets.builder import build_dataloader
|
||||
from mmcls.datasets.pipelines import Compose
|
||||
from mmcls.models.builder import build_classifier
|
||||
from mmpretrain.datasets.builder import build_dataloader
|
||||
from mmpretrain.datasets.pipelines import Compose
|
||||
from mmpretrain.models.builder import build_classifier
|
||||
|
||||
console = Console()
|
||||
MMCLS_ROOT = Path(__file__).absolute().parents[2]
|
||||
|
|
|
@ -47,7 +47,7 @@ def ckpt_to_state_dict(checkpoint, key=None):
|
|||
if key is not None:
|
||||
state_dict = checkpoint[key]
|
||||
elif 'state_dict' in checkpoint:
|
||||
# try mmcls style
|
||||
# try mmpretrain style
|
||||
state_dict = checkpoint['state_dict']
|
||||
elif 'model' in checkpoint:
|
||||
state_dict = checkpoint['model']
|
||||
|
@ -149,7 +149,7 @@ def main():
|
|||
if args.path.suffix in ['.json', '.py', '.yml']:
|
||||
from mmengine.runner import get_state_dict
|
||||
|
||||
from mmcls.apis import init_model
|
||||
from mmpretrain.apis import init_model
|
||||
model = init_model(args.path, device='cpu')
|
||||
state_dict = get_state_dict(model)
|
||||
else:
|
||||
|
|
|
@ -55,7 +55,7 @@ def state_dict_from_cfg_or_ckpt(path, state_key=None):
|
|||
if path.suffix in ['.json', '.py', '.yml']:
|
||||
from mmengine.runner import get_state_dict
|
||||
|
||||
from mmcls.apis import init_model
|
||||
from mmpretrain.apis import init_model
|
||||
model = init_model(path, device='cpu')
|
||||
model.init_weights()
|
||||
return get_state_dict(model)
|
||||
|
|
|
@ -84,8 +84,8 @@ def get_flops(config_path):
|
|||
from mmengine.dataset import Compose
|
||||
from mmengine.registry import DefaultScope
|
||||
|
||||
import mmcls.datasets # noqa: F401
|
||||
from mmcls.apis import init_model
|
||||
import mmpretrain.datasets # noqa: F401
|
||||
from mmpretrain.apis import init_model
|
||||
|
||||
cfg = Config.fromfile(config_path)
|
||||
|
||||
|
@ -98,7 +98,7 @@ def get_flops(config_path):
|
|||
# The image shape of CIFAR is (32, 32, 3)
|
||||
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
|
||||
|
||||
with DefaultScope.overwrite_default_scope('mmcls'):
|
||||
with DefaultScope.overwrite_default_scope('mmpretrain'):
|
||||
data = Compose(test_dataset.pipeline)({
|
||||
'img':
|
||||
np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
|
||||
|
|
|
@ -127,7 +127,7 @@ venv.bak/
|
|||
/work_dirs
|
||||
/projects/*/work_dirs
|
||||
/projects/*/data
|
||||
/mmcls/.mim
|
||||
/mmpretrain/.mim
|
||||
.DS_Store
|
||||
|
||||
# Pytorch
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
include requirements/*.txt
|
||||
include mmcls/.mim/model-index.yml
|
||||
recursive-include mmcls/.mim/configs *.py *.yml
|
||||
recursive-include mmcls/.mim/tools *.py *.sh
|
||||
include mmpretrain/.mim/model-index.yml
|
||||
recursive-include mmpretrain/.mim/configs *.py *.yml
|
||||
recursive-include mmpretrain/.mim/tools *.py *.sh
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# defaults to use registries in mmcls
|
||||
default_scope = 'mmcls'
|
||||
# defaults to use registries in mmpretrain
|
||||
default_scope = 'mmpretrain'
|
||||
|
||||
# configure default hooks
|
||||
default_hooks = dict(
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from mmcls.models import build_classifier
|
||||
from mmpretrain.models import build_classifier
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
|
|
|
@ -2,11 +2,11 @@ _base_ = ['../_base_/datasets/voc_bs16.py', '../_base_/default_runtime.py']
|
|||
|
||||
# Pre-trained Checkpoint Path
|
||||
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth' # noqa
|
||||
# If you want to use the pre-trained weight of ResNet101-CutMix from
|
||||
# the originary repo(https://github.com/Kevinz-code/CSRA). Script of
|
||||
# 'tools/convert_models/torchvision_to_mmcls.py' can help you convert weight
|
||||
# into mmcls format. The mAP result would hit 95.5 by using the weight.
|
||||
# checkpoint = 'PATH/TO/PRE-TRAINED_WEIGHT'
|
||||
# If you want to use the pre-trained weight of ResNet101-CutMix from the
|
||||
# originary repo(https://github.com/Kevinz-code/CSRA). Script of
|
||||
# 'tools/model_converters/torchvision_to_mmpretrain.py' can help you convert
|
||||
# weight into mmpretrain format. The mAP result would hit 95.5 by using the
|
||||
# weight. checkpoint = 'PATH/TO/PRE-TRAINED_WEIGHT'
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
|
|
|
@ -51,7 +51,7 @@ val_cfg = dict()
|
|||
test_cfg = dict()
|
||||
|
||||
# runtime settings
|
||||
default_scope = 'mmcls'
|
||||
default_scope = 'mmpretrain'
|
||||
|
||||
default_hooks = dict(
|
||||
# record the time of every iteration.
|
||||
|
|
|
@ -4,7 +4,7 @@ from argparse import ArgumentParser
|
|||
from mmengine.fileio import dump
|
||||
from rich import print_json
|
||||
|
||||
from mmcls.apis import ImageClassificationInferencer
|
||||
from mmpretrain.apis import ImageClassificationInferencer
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -30,7 +30,7 @@ def main():
|
|||
raise ValueError(
|
||||
f'Unavailable model "{args.model}", you can specify find a model '
|
||||
'name or a config file or find a model name from '
|
||||
'https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html#all-checkpoints' # noqa: E501
|
||||
'https://mmpretrain.readthedocs.io/en/1.x/modelzoo_statistics.html#all-checkpoints' # noqa: E501
|
||||
)
|
||||
result = inferencer(args.img, show=args.show, show_dir=args.show_dir)[0]
|
||||
# show the results
|
||||
|
|
|
@ -22,12 +22,12 @@ sys.path.insert(0, os.path.abspath('../../'))
|
|||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = 'MMClassification'
|
||||
project = 'MMPretrain'
|
||||
copyright = '2020, OpenMMLab'
|
||||
author = 'MMClassification Authors'
|
||||
author = 'MMPretrain Authors'
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
version_file = '../../mmcls/version.py'
|
||||
version_file = '../../mmpretrain/version.py'
|
||||
|
||||
|
||||
def get_version():
|
||||
|
@ -92,25 +92,25 @@ html_theme_options = {
|
|||
'menu': [
|
||||
{
|
||||
'name': 'GitHub',
|
||||
'url': 'https://github.com/open-mmlab/mmclassification'
|
||||
'url': 'https://github.com/open-mmlab/mmpretrain'
|
||||
},
|
||||
{
|
||||
'name': 'Colab Tutorials',
|
||||
'children': [
|
||||
{'name': 'Train and inference with shell commands',
|
||||
'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_tools.ipynb'},
|
||||
'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_tools.ipynb'},
|
||||
{'name': 'Train and inference with Python APIs',
|
||||
'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_python.ipynb'},
|
||||
'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_python.ipynb'},
|
||||
]
|
||||
},
|
||||
{
|
||||
'name': 'Version',
|
||||
'children': [
|
||||
{'name': 'MMClassification 0.x',
|
||||
'url': 'https://mmclassification.readthedocs.io/en/latest/',
|
||||
{'name': 'MMPretrain 0.x',
|
||||
'url': 'https://mmpretrain.readthedocs.io/en/latest/',
|
||||
'description': 'master branch'},
|
||||
{'name': 'MMClassification 1.x',
|
||||
'url': 'https://mmclassification.readthedocs.io/en/dev-1.x/',
|
||||
{'name': 'MMPretrain 1.x',
|
||||
'url': 'https://mmpretrain.readthedocs.io/en/dev-1.x/',
|
||||
'description': '1.x branch'},
|
||||
],
|
||||
}
|
||||
|
@ -138,7 +138,7 @@ html_js_files = [
|
|||
# -- Options for HTMLHelp output ---------------------------------------------
|
||||
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = 'mmclsdoc'
|
||||
htmlhelp_basename = 'mmpretraindoc'
|
||||
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
|
@ -160,16 +160,14 @@ latex_elements = {
|
|||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(root_doc, 'mmcls.tex', 'MMClassification Documentation', author,
|
||||
'manual'),
|
||||
(root_doc, 'mmpretrain.tex', 'MMPretrain Documentation', author, 'manual'),
|
||||
]
|
||||
|
||||
# -- Options for manual page output ------------------------------------------
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
|
||||
]
|
||||
man_pages = [(root_doc, 'mmpretrain', 'MMPretrain Documentation', [author], 1)]
|
||||
|
||||
# -- Options for Texinfo output ----------------------------------------------
|
||||
|
||||
|
@ -177,7 +175,7 @@ man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
|
|||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(root_doc, 'mmcls', 'MMClassification Documentation', author, 'mmcls',
|
||||
(root_doc, 'mmpretrain', 'MMPretrain Documentation', author, 'mmpretrain',
|
||||
'OpenMMLab image classification toolbox and benchmark.', 'Miscellaneous'),
|
||||
]
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from tabulate import tabulate
|
|||
|
||||
MMCLS_ROOT = Path(__file__).absolute().parents[2]
|
||||
PAPERS_ROOT = Path('papers') # Path to save generated paper pages.
|
||||
GITHUB_PREFIX = 'https://github.com/open-mmlab/mmclassification/blob/1.x/'
|
||||
GITHUB_PREFIX = 'https://github.com/open-mmlab/mmpretrain/blob/1.x/'
|
||||
MODELZOO_TEMPLATE = """
|
||||
# Model Zoo Summary
|
||||
|
||||
|
|
|
@ -22,12 +22,12 @@ sys.path.insert(0, os.path.abspath('../..'))
|
|||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = 'MMClassification'
|
||||
project = 'MMPretrain'
|
||||
copyright = '2020, OpenMMLab'
|
||||
author = 'MMClassification Authors'
|
||||
author = 'MMPretrain Authors'
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
version_file = '../../mmcls/version.py'
|
||||
version_file = '../../mmpretrain/version.py'
|
||||
|
||||
|
||||
def get_version():
|
||||
|
@ -92,25 +92,25 @@ html_theme_options = {
|
|||
'menu': [
|
||||
{
|
||||
'name': 'GitHub',
|
||||
'url': 'https://github.com/open-mmlab/mmclassification'
|
||||
'url': 'https://github.com/open-mmlab/mmpretrain'
|
||||
},
|
||||
{
|
||||
'name': 'Colab 教程',
|
||||
'children': [
|
||||
{'name': '用命令行工具训练和推理',
|
||||
'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_tools.ipynb'},
|
||||
'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_tools.ipynb'},
|
||||
{'name': '用 Python API 训练和推理',
|
||||
'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_python.ipynb'},
|
||||
'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_python.ipynb'},
|
||||
]
|
||||
},
|
||||
{
|
||||
'name': 'Version',
|
||||
'children': [
|
||||
{'name': 'MMClassification 0.x',
|
||||
'url': 'https://mmclassification.readthedocs.io/zh_CN/latest/',
|
||||
{'name': 'MMPretrain 0.x',
|
||||
'url': 'https://mmpretrain.readthedocs.io/zh_CN/latest/',
|
||||
'description': 'master branch'},
|
||||
{'name': 'MMClassification 1.x',
|
||||
'url': 'https://mmclassification.readthedocs.io/zh_CN/dev-1.x/',
|
||||
{'name': 'MMPretrain 1.x',
|
||||
'url': 'https://mmpretrain.readthedocs.io/zh_CN/dev-1.x/',
|
||||
'description': '1.x branch'},
|
||||
],
|
||||
}
|
||||
|
@ -138,7 +138,7 @@ html_js_files = [
|
|||
# -- Options for HTMLHelp output ---------------------------------------------
|
||||
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = 'mmclsdoc'
|
||||
htmlhelp_basename = 'mmpretraindoc'
|
||||
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
|
@ -164,16 +164,14 @@ latex_elements = {
|
|||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(root_doc, 'mmcls.tex', 'MMClassification Documentation', author,
|
||||
'manual'),
|
||||
(root_doc, 'mmpretrain.tex', 'MMPretrain Documentation', author, 'manual'),
|
||||
]
|
||||
|
||||
# -- Options for manual page output ------------------------------------------
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
|
||||
]
|
||||
man_pages = [(root_doc, 'mmpretrain', 'MMPretrain Documentation', [author], 1)]
|
||||
|
||||
# -- Options for Texinfo output ----------------------------------------------
|
||||
|
||||
|
@ -181,7 +179,7 @@ man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
|
|||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(root_doc, 'mmcls', 'MMClassification Documentation', author, 'mmcls',
|
||||
(root_doc, 'mmpretrain', 'MMPretrain Documentation', author, 'mmpretrain',
|
||||
'OpenMMLab image classification toolbox and benchmark.', 'Miscellaneous'),
|
||||
]
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from tabulate import tabulate
|
|||
|
||||
MMCLS_ROOT = Path(__file__).absolute().parents[2]
|
||||
PAPERS_ROOT = Path('papers') # Path to save generated paper pages.
|
||||
GITHUB_PREFIX = 'https://github.com/open-mmlab/mmclassification/blob/1.x/'
|
||||
GITHUB_PREFIX = 'https://github.com/open-mmlab/mmpretrain/blob/1.x/'
|
||||
MODELZOO_TEMPLATE = """
|
||||
# 模型库统计
|
||||
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import datetime
|
||||
import warnings
|
||||
|
||||
from mmengine import DefaultScope
|
||||
|
||||
|
||||
def register_all_modules(init_default_scope: bool = True) -> None:
|
||||
"""Register all modules in mmcls into the registries.
|
||||
|
||||
Args:
|
||||
init_default_scope (bool): Whether initialize the mmcls default scope.
|
||||
If True, the global default scope will be set to `mmcls`, and all
|
||||
registries will build modules from mmcls's registry node. To
|
||||
understand more about the registry, please refer to
|
||||
https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
|
||||
Defaults to True.
|
||||
""" # noqa
|
||||
import mmcls.datasets # noqa: F401,F403
|
||||
import mmcls.engine # noqa: F401,F403
|
||||
import mmcls.evaluation # noqa: F401,F403
|
||||
import mmcls.models # noqa: F401,F403
|
||||
import mmcls.structures # noqa: F401,F403
|
||||
import mmcls.visualization # noqa: F401,F403
|
||||
|
||||
if not init_default_scope:
|
||||
return
|
||||
|
||||
current_scope = DefaultScope.get_current_instance()
|
||||
if current_scope is None:
|
||||
DefaultScope.get_instance('mmcls', scope_name='mmcls')
|
||||
elif current_scope.scope_name != 'mmcls':
|
||||
warnings.warn(f'The current default scope "{current_scope.scope_name}"'
|
||||
' is not "mmcls", `register_all_modules` will force the '
|
||||
'current default scope to be "mmcls". If this is not '
|
||||
'expected, please set `init_default_scope=False`.')
|
||||
# avoid name conflict
|
||||
new_instance_name = f'mmcls-{datetime.datetime.now()}'
|
||||
DefaultScope.get_instance(new_instance_name, scope_name='mmcls')
|
|
@ -11,8 +11,8 @@ from mmengine.infer import BaseInferencer
|
|||
from mmengine.model import BaseModel
|
||||
from mmengine.runner import load_checkpoint
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmpretrain.registry import TRANSFORMS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from .model import get_model, init_model, list_models
|
||||
|
||||
ModelType = Union[BaseModel, str, Config]
|
||||
|
@ -63,7 +63,7 @@ class ImageClassificationInferencer(BaseInferencer):
|
|||
Example:
|
||||
1. Use a pre-trained model in MMClassification to inference an image.
|
||||
|
||||
>>> from mmcls import ImageClassificationInferencer
|
||||
>>> from mmpretrain import ImageClassificationInferencer
|
||||
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
|
||||
>>> inferencer('demo/demo.JPEG')
|
||||
[{'pred_score': array([...]),
|
||||
|
@ -74,7 +74,7 @@ class ImageClassificationInferencer(BaseInferencer):
|
|||
2. Use a config file and checkpoint to inference multiple images on GPU,
|
||||
and save the visualization results in a folder.
|
||||
|
||||
>>> from mmcls import ImageClassificationInferencer
|
||||
>>> from mmpretrain import ImageClassificationInferencer
|
||||
>>> inferencer = ImageClassificationInferencer(
|
||||
model='configs/resnet/resnet50_8xb32_in1k.py',
|
||||
weights='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
|
||||
|
@ -107,7 +107,7 @@ class ImageClassificationInferencer(BaseInferencer):
|
|||
else:
|
||||
raise TypeError(
|
||||
'The `model` can be a name of model and you can use '
|
||||
'`mmcls.list_models` to get an available name. It can '
|
||||
'`mmpretrain.list_models` to get an available name. It can '
|
||||
'also be a Config object or a path to the config file.')
|
||||
|
||||
model.eval()
|
||||
|
@ -185,7 +185,7 @@ class ImageClassificationInferencer(BaseInferencer):
|
|||
return None
|
||||
|
||||
if self.visualizer is None:
|
||||
from mmcls.visualization import ClsVisualizer
|
||||
from mmpretrain.visualization import ClsVisualizer
|
||||
self.visualizer = ClsVisualizer()
|
||||
if self.classes is not None:
|
||||
self.visualizer._dataset_meta = dict(classes=self.classes)
|
|
@ -15,7 +15,7 @@ from modelindex.models.Model import Model
|
|||
class ModelHub:
|
||||
"""A hub to host the meta information of all pre-defined models."""
|
||||
_models_dict = {}
|
||||
__mmcls_registered = False
|
||||
__mmpretrain_registered = False
|
||||
|
||||
@classmethod
|
||||
def register_model_index(cls,
|
||||
|
@ -52,7 +52,7 @@ class ModelHub:
|
|||
Returns:
|
||||
modelindex.models.Model: The metainfo of the specified model.
|
||||
"""
|
||||
cls._register_mmcls_models()
|
||||
cls._register_mmpretrain_models()
|
||||
# lazy load config
|
||||
metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
|
||||
if metainfo is None:
|
||||
|
@ -75,15 +75,15 @@ class ModelHub:
|
|||
return config_path
|
||||
|
||||
@classmethod
|
||||
def _register_mmcls_models(cls):
|
||||
# register models in mmcls
|
||||
if not cls.__mmcls_registered:
|
||||
def _register_mmpretrain_models(cls):
|
||||
# register models in mmpretrain
|
||||
if not cls.__mmpretrain_registered:
|
||||
from mmengine.utils import get_installed_path
|
||||
mmcls_root = Path(get_installed_path('mmcls'))
|
||||
model_index_path = mmcls_root / '.mim' / 'model-index.yml'
|
||||
mmpretrain_root = Path(get_installed_path('mmpretrain'))
|
||||
model_index_path = mmpretrain_root / '.mim' / 'model-index.yml'
|
||||
ModelHub.register_model_index(
|
||||
model_index_path, config_prefix=mmcls_root / '.mim')
|
||||
cls.__mmcls_registered = True
|
||||
model_index_path, config_prefix=mmpretrain_root / '.mim')
|
||||
cls.__mmpretrain_registered = True
|
||||
|
||||
@classmethod
|
||||
def has(cls, model_name):
|
||||
|
@ -118,7 +118,7 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
|
|||
config.model.setdefault('data_preprocessor',
|
||||
config.get('data_preprocessor', None))
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
model = MODELS.build(config.model)
|
||||
if checkpoint is not None:
|
||||
|
@ -130,13 +130,13 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
|
|||
# Don't set CLASSES if the model is headless.
|
||||
pass
|
||||
elif 'dataset_meta' in checkpoint.get('meta', {}):
|
||||
# mmcls 1.x
|
||||
# mmpretrain 1.x
|
||||
model.CLASSES = checkpoint['meta']['dataset_meta'].get('classes')
|
||||
elif 'CLASSES' in checkpoint.get('meta', {}):
|
||||
# mmcls < 1.x
|
||||
# mmpretrain < 1.x or mmselfsup < 1.x
|
||||
model.CLASSES = checkpoint['meta']['CLASSES']
|
||||
else:
|
||||
from mmcls.datasets.categories import IMAGENET_CATEGORIES
|
||||
from mmpretrain.datasets.categories import IMAGENET_CATEGORIES
|
||||
warnings.simplefilter('once')
|
||||
warnings.warn('Class names are not saved in the checkpoint\'s '
|
||||
'meta data, use imagenet by default.')
|
||||
|
@ -165,7 +165,7 @@ def get_model(model_name, pretrained=False, device=None, **kwargs):
|
|||
Get a ResNet-50 model and extract images feature:
|
||||
|
||||
>>> import torch
|
||||
>>> from mmcls import get_model
|
||||
>>> from mmpretrain import get_model
|
||||
>>> inputs = torch.rand(16, 3, 224, 224)
|
||||
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
|
||||
>>> feats = model.extract_feat(inputs)
|
||||
|
@ -178,7 +178,7 @@ def get_model(model_name, pretrained=False, device=None, **kwargs):
|
|||
|
||||
Get Swin-Transformer model with pre-trained weights and inference:
|
||||
|
||||
>>> from mmcls import get_model, inference_model
|
||||
>>> from mmpretrain import get_model, inference_model
|
||||
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
|
||||
>>> result = inference_model(model, 'demo/demo.JPEG')
|
||||
>>> print(result['pred_class'])
|
||||
|
@ -201,7 +201,7 @@ def get_model(model_name, pretrained=False, device=None, **kwargs):
|
|||
|
||||
|
||||
def list_models(pattern=None) -> List[str]:
|
||||
"""List all models available in MMClassification.
|
||||
"""List all models available in MMPretrain.
|
||||
|
||||
Args:
|
||||
pattern (str | None): A wildcard pattern to match model names.
|
||||
|
@ -212,12 +212,12 @@ def list_models(pattern=None) -> List[str]:
|
|||
Examples:
|
||||
List all models:
|
||||
|
||||
>>> from mmcls import list_models
|
||||
>>> from mmpretrain import list_models
|
||||
>>> print(list_models())
|
||||
|
||||
List ResNet-50 models on ImageNet-1k dataset:
|
||||
|
||||
>>> from mmcls import list_models
|
||||
>>> from mmpretrain import list_models
|
||||
>>> print(list_models('resnet*in1k'))
|
||||
['resnet50_8xb32_in1k',
|
||||
'resnet50_8xb32-fp16_in1k',
|
||||
|
@ -225,7 +225,7 @@ def list_models(pattern=None) -> List[str]:
|
|||
'resnet50_8xb256-rsb-a2-300e_in1k',
|
||||
'resnet50_8xb256-rsb-a3-100e_in1k']
|
||||
"""
|
||||
ModelHub._register_mmcls_models()
|
||||
ModelHub._register_mmpretrain_models()
|
||||
if pattern is None:
|
||||
return sorted(list(ModelHub._models_dict.keys()))
|
||||
# Always match keys with any postfix.
|
|
@ -7,7 +7,7 @@ import mmengine
|
|||
import numpy as np
|
||||
from mmengine.dataset import BaseDataset as _BaseDataset
|
||||
|
||||
from mmcls.registry import DATASETS, TRANSFORMS
|
||||
from mmpretrain.registry import DATASETS, TRANSFORMS
|
||||
|
||||
|
||||
def expanduser(path):
|
|
@ -1,12 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcls.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS
|
||||
|
||||
|
||||
def build_dataset(cfg):
|
||||
"""Build dataset.
|
||||
|
||||
Examples:
|
||||
>>> from mmcls.datasets import build_dataset
|
||||
>>> from mmpretrain.datasets import build_dataset
|
||||
>>> mnist_train = build_dataset(
|
||||
... dict(type='MNIST', data_prefix='data/mnist/', test_mode=False))
|
||||
>>> print(mnist_train)
|
|
@ -7,7 +7,7 @@ import numpy as np
|
|||
from mmengine.fileio import (LocalBackend, exists, get, get_file_backend,
|
||||
join_path)
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
from .categories import CIFAR10_CATEGORIES, CIFAR100_CATEGORIES
|
||||
from .utils import check_md5, download_and_extract_archive
|
|
@ -3,7 +3,7 @@ from typing import List
|
|||
|
||||
from mmengine import get_file_backend, list_from_file
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
from .categories import CUB_CATEGORIES
|
||||
|
||||
|
@ -51,7 +51,7 @@ class CUB(BaseDataset):
|
|||
|
||||
|
||||
Examples:
|
||||
>>> from mmcls.datasets import CUB
|
||||
>>> from mmpretrain.datasets import CUB
|
||||
>>> cub_train_cfg = dict(data_root='data/CUB_200_2011', test_mode=True)
|
||||
>>> cub_train = CUB(**cub_train_cfg)
|
||||
>>> cub_train
|
|
@ -5,7 +5,7 @@ from mmengine.fileio import (BaseStorageBackend, get_file_backend,
|
|||
list_from_file)
|
||||
from mmengine.logging import MMLogger
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
|
@ -4,7 +4,7 @@ import copy
|
|||
import numpy as np
|
||||
from mmengine.dataset import BaseDataset, force_full_init
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
|
@ -3,7 +3,7 @@ from typing import Optional, Union
|
|||
|
||||
from mmengine.logging import MMLogger
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS
|
||||
from .categories import IMAGENET_CATEGORIES
|
||||
from .custom import CustomDataset
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine import get_file_backend, list_from_file
|
||||
|
||||
from mmcls.datasets.base_dataset import BaseDataset
|
||||
from mmcls.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -35,7 +35,7 @@ class InShop(BaseDataset):
|
|||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
|
||||
Examples:
|
||||
>>> from mmcls.datasets import InShop
|
||||
>>> from mmpretrain.datasets import InShop
|
||||
>>>
|
||||
>>> # build train InShop dataset
|
||||
>>> inshop_train_cfg = dict(data_root='data/inshop', split='train')
|
|
@ -8,7 +8,7 @@ import numpy as np
|
|||
import torch
|
||||
from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
|
||||
from .utils import (download_and_extract_archive, open_maybe_compressed_file,
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
|
@ -60,7 +60,7 @@ class MultiTaskDataset:
|
|||
Assume we put our dataset in the ``data/mydataset`` folder in the
|
||||
repository and organize it as the below format: ::
|
||||
|
||||
mmclassification/
|
||||
mmpretrain/
|
||||
└── data
|
||||
└── mydataset
|
||||
├── annotation
|
||||
|
@ -81,7 +81,7 @@ class MultiTaskDataset:
|
|||
|
||||
.. code:: python
|
||||
|
||||
>>> from mmcls.datasets import build_dataset
|
||||
>>> from mmpretrain.datasets import build_dataset
|
||||
>>> train_cfg = dict(
|
||||
... type="MultiTaskDataset",
|
||||
... ann_file="annotation/train.json",
|
||||
|
@ -94,7 +94,7 @@ class MultiTaskDataset:
|
|||
|
||||
Or we can put all files in the same folder: ::
|
||||
|
||||
mmclassification/
|
||||
mmpretrain/
|
||||
└── data
|
||||
└── mydataset
|
||||
├── train.json
|
||||
|
@ -109,7 +109,7 @@ class MultiTaskDataset:
|
|||
|
||||
.. code:: python
|
||||
|
||||
>>> from mmcls.datasets import build_dataset
|
||||
>>> from mmpretrain.datasets import build_dataset
|
||||
>>> train_cfg = dict(
|
||||
... type="MultiTaskDataset",
|
||||
... ann_file="train.json",
|
||||
|
@ -133,8 +133,8 @@ class MultiTaskDataset:
|
|||
``data_root`` for the ``"img_path"`` field in the annotation file.
|
||||
Defaults to None.
|
||||
pipeline (Sequence[dict]): A list of dict, where each element
|
||||
represents a operation defined in :mod:`mmcls.datasets.pipelines`.
|
||||
Defaults to an empty tuple.
|
||||
represents a operation defined in
|
||||
:mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple.
|
||||
test_mode (bool): in train mode or test mode. Defaults to False.
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
|
@ -5,7 +5,7 @@ import torch
|
|||
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from mmcls.registry import DATA_SAMPLERS
|
||||
from mmpretrain.registry import DATA_SAMPLERS
|
||||
|
||||
|
||||
@DATA_SAMPLERS.register_module()
|
|
@ -11,7 +11,7 @@ from mmcv.transforms import BaseTransform, Compose, RandomChoice
|
|||
from mmcv.transforms.utils import cache_randomness
|
||||
from mmengine.utils import is_list_of, is_seq_of
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from mmpretrain.registry import TRANSFORMS
|
||||
|
||||
|
||||
def merge_hparams(policy: dict, hparams: dict) -> dict:
|
||||
|
@ -125,7 +125,7 @@ class RandAugment(BaseTransform):
|
|||
time, and magnitude_level of every policy is 6 (total is 10 by default)
|
||||
|
||||
>>> import numpy as np
|
||||
>>> from mmcls.datasets import RandAugment
|
||||
>>> from mmpretrain.datasets import RandAugment
|
||||
>>> transform = RandAugment(
|
||||
... policies='timm_increasing',
|
||||
... num_policies=2,
|
|
@ -9,8 +9,8 @@ from mmcv.transforms import BaseTransform
|
|||
from mmengine.utils import is_str
|
||||
from PIL import Image
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from mmcls.structures import ClsDataSample, MultiTaskDataSample
|
||||
from mmpretrain.registry import TRANSFORMS
|
||||
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
|
||||
|
||||
|
||||
def to_tensor(data):
|
||||
|
@ -53,8 +53,8 @@ class PackClsInputs(BaseTransform):
|
|||
**Added Keys:**
|
||||
|
||||
- inputs (:obj:`torch.Tensor`): The forward data of models.
|
||||
- data_samples (:obj:`~mmcls.structures.ClsDataSample`): The annotation
|
||||
info of the sample.
|
||||
- data_samples (:obj:`~mmpretrain.structures.ClsDataSample`): The
|
||||
annotation info of the sample.
|
||||
|
||||
Args:
|
||||
meta_keys (Sequence[str]): The meta keys to be saved in the
|
|
@ -11,7 +11,7 @@ import numpy as np
|
|||
from mmcv.transforms import BaseTransform
|
||||
from mmcv.transforms.utils import cache_randomness
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from mmpretrain.registry import TRANSFORMS
|
||||
|
||||
try:
|
||||
import albumentations
|
||||
|
@ -1008,7 +1008,7 @@ class Lighting(BaseTransform):
|
|||
return repr_str
|
||||
|
||||
|
||||
# 'Albu' is used in previous versions of mmcls, here is for compatibility
|
||||
# 'Albu' is used in previous versions of mmpretrain, here is for compatibility
|
||||
# users can use both 'Albumentations' and 'Albu'.
|
||||
@TRANSFORMS.register_module(['Albumentations', 'Albu'])
|
||||
class Albumentations(BaseTransform):
|
||||
|
@ -1055,13 +1055,13 @@ class Albumentations(BaseTransform):
|
|||
|
||||
Args:
|
||||
transforms (List[Dict]): List of albumentations transform configs.
|
||||
keymap (Optional[Dict]): Mapping of mmcls to albumentations fields,
|
||||
in format {'input key':'albumentation-style key'}. Defaults to
|
||||
None.
|
||||
keymap (Optional[Dict]): Mapping of mmpretrain to albumentations
|
||||
fields, in format {'input key':'albumentation-style key'}.
|
||||
Defaults to None.
|
||||
|
||||
Example:
|
||||
>>> import mmcv
|
||||
>>> from mmcls.datasets import Albumentations
|
||||
>>> from mmpretrain.datasets import Albumentations
|
||||
>>> transforms = [
|
||||
... dict(
|
||||
... type='ShiftScaleRotate',
|
|
@ -4,7 +4,7 @@ from typing import List, Optional, Union
|
|||
|
||||
from mmengine import get_file_backend, list_from_file
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS
|
||||
from .base_dataset import expanduser
|
||||
from .categories import VOC2007_CATEGORIES
|
||||
from .multi_label import MultiLabelDataset
|
|
@ -2,7 +2,7 @@
|
|||
from mmengine.hooks import Hook
|
||||
from mmengine.utils import is_seq_of
|
||||
|
||||
from mmcls.registry import HOOKS
|
||||
from mmpretrain.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
|
@ -8,7 +8,7 @@ from mmengine.hooks import EMAHook as BaseEMAHook
|
|||
from mmengine.logging import MMLogger
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmcls.registry import HOOKS
|
||||
from mmpretrain.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
|
@ -3,8 +3,8 @@ import numpy as np
|
|||
from mmengine.hooks import Hook
|
||||
from mmengine.model import is_model_wrapper
|
||||
|
||||
from mmcls.models.heads import ArcFaceClsHead
|
||||
from mmcls.registry import HOOKS
|
||||
from mmpretrain.models.heads import ArcFaceClsHead
|
||||
from mmpretrain.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
|
@ -20,7 +20,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
|
|||
from torch.nn.modules.instancenorm import _InstanceNorm
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmcls.registry import HOOKS
|
||||
from mmpretrain.registry import HOOKS
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
|
@ -4,8 +4,8 @@ import warnings
|
|||
from mmengine.hooks import Hook
|
||||
from mmengine.model import is_model_wrapper
|
||||
|
||||
from mmcls.models import BaseRetriever
|
||||
from mmcls.registry import HOOKS
|
||||
from mmpretrain.models import BaseRetriever
|
||||
from mmpretrain.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -27,5 +27,6 @@ class PrepareProtoBeforeValLoopHook(Hook):
|
|||
model.prepare_prototype()
|
||||
else:
|
||||
warnings.warn(
|
||||
'Only the `mmcls.models.retrievers.BaseRetriever` can execute '
|
||||
f'`PrepareRetrieverPrototypeHook`, but got `{type(model)}`')
|
||||
'Only the `mmpretrain.models.retrievers.BaseRetriever` '
|
||||
'can execute `PrepareRetrieverPrototypeHook`, but got '
|
||||
f'`{type(model)}`')
|
|
@ -6,8 +6,8 @@ from mmcv.transforms import Compose
|
|||
from mmengine.hooks import Hook
|
||||
from mmengine.model import is_model_wrapper
|
||||
|
||||
from mmcls.models.utils import RandomBatchAugment
|
||||
from mmcls.registry import HOOKS, MODEL_WRAPPERS, MODELS
|
||||
from mmpretrain.models.utils import RandomBatchAugment
|
||||
from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -25,9 +25,9 @@ class SwitchRecipeHook(Hook):
|
|||
train dataset. If not specified, keep the original settings.
|
||||
- ``batch_augments`` (dict | None, optional): The new batch
|
||||
augmentations of during training. See :mod:`Batch Augmentations
|
||||
<mmcls.models.utils.batch_augments>` for more details. If None,
|
||||
disable batch augmentations. If not specified, keep the original
|
||||
settings.
|
||||
<mmpretrain.models.utils.batch_augments>` for more details.
|
||||
If None, disable batch augmentations. If not specified, keep the
|
||||
original settings.
|
||||
- ``loss`` (dict, optional): The new loss module config. If not
|
||||
specified, keep the original settings.
|
||||
|
|
@ -8,8 +8,8 @@ from mmengine.hooks import Hook
|
|||
from mmengine.runner import EpochBasedTrainLoop, Runner
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
from mmcls.registry import HOOKS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmpretrain.registry import HOOKS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -30,7 +30,7 @@ class VisualizationHook(Hook):
|
|||
in the testing process. If None, handle with the backends of the
|
||||
visualizer. Defaults to None.
|
||||
**kwargs: other keyword arguments of
|
||||
:meth:`mmcls.visualization.ClsVisualizer.add_datasample`.
|
||||
:meth:`mmpretrain.visualization.ClsVisualizer.add_datasample`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
|
@ -19,7 +19,7 @@ import torch
|
|||
from torch import Tensor
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from mmcls.registry import OPTIMIZERS
|
||||
from mmpretrain.registry import OPTIMIZERS
|
||||
|
||||
|
||||
@OPTIMIZERS.register_module()
|
|
@ -61,7 +61,7 @@ import math
|
|||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from mmcls.registry import OPTIMIZERS
|
||||
from mmpretrain.registry import OPTIMIZERS
|
||||
|
||||
|
||||
@OPTIMIZERS.register_module()
|
|
@ -1,2 +1,3 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .functional import * # noqa: F401,F403
|
||||
from .metrics import * # noqa: F401,F403
|
|
@ -7,7 +7,7 @@ from mmengine.evaluator import BaseMetric
|
|||
from mmengine.logging import MMLogger
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmcls.registry import METRICS
|
||||
from mmpretrain.registry import METRICS
|
||||
from .single_label import _precision_recall_f1_support, to_tensor
|
||||
|
||||
|
||||
|
@ -77,7 +77,7 @@ class MultiLabelMetric(BaseMetric):
|
|||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmcls.evaluation import MultiLabelMetric
|
||||
>>> from mmpretrain.evaluation import MultiLabelMetric
|
||||
>>> # ------ The Basic Usage for category indices labels -------
|
||||
>>> y_pred = [[0], [1], [0, 1], [3]]
|
||||
>>> y_true = [[0, 3], [0, 2], [1], [3]]
|
||||
|
@ -114,7 +114,7 @@ class MultiLabelMetric(BaseMetric):
|
|||
(tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8))
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_sampels = [
|
||||
... ClsDataSample().set_pred_score(pred).set_gt_score(gt)
|
||||
|
@ -466,7 +466,7 @@ class AveragePrecision(BaseMetric):
|
|||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmcls.evaluation import AveragePrecision
|
||||
>>> from mmpretrain.evaluation import AveragePrecision
|
||||
>>> # --------- The Basic Usage for one-hot pred scores ---------
|
||||
>>> y_pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2],
|
||||
... [0.1, 0.2, 0.2, 0.1],
|
||||
|
@ -479,7 +479,7 @@ class AveragePrecision(BaseMetric):
|
|||
>>> AveragePrecision.calculate(y_pred, y_true)
|
||||
tensor(70.833)
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_pred_score(i).set_gt_score(j)
|
|
@ -3,7 +3,7 @@ from typing import Dict, Sequence
|
|||
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
from mmcls.registry import METRICS
|
||||
from mmpretrain.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
|
@ -14,7 +14,7 @@ class MultiTasksMetric(BaseMetric):
|
|||
and the values is a list of the metric corresponds to this task
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmcls.evaluation import MultiTasksMetric
|
||||
>>> from mmpretrain.evaluation import MultiTasksMetric
|
||||
# -------------------- The Basic Usage --------------------
|
||||
>>>task_metrics = {
|
||||
'task0': [dict(type='Accuracy', topk=(1, ))],
|
|
@ -8,7 +8,7 @@ from mmengine.evaluator import BaseMetric
|
|||
from mmengine.structures import LabelData
|
||||
from mmengine.utils import is_seq_of
|
||||
|
||||
from mmcls.registry import METRICS
|
||||
from mmpretrain.registry import METRICS
|
||||
from .single_label import to_tensor
|
||||
|
||||
|
||||
|
@ -33,7 +33,7 @@ class RetrievalRecall(BaseMetric):
|
|||
Use in the code:
|
||||
|
||||
>>> import torch
|
||||
>>> from mmcls.evaluation import RetrievalRecall
|
||||
>>> from mmpretrain.evaluation import RetrievalRecall
|
||||
>>> # -------------------- The Basic Usage --------------------
|
||||
>>> y_pred = [[0], [1], [2], [3]]
|
||||
>>> y_true = [[0, 1], [2], [1], [0, 3]]
|
||||
|
@ -48,7 +48,7 @@ class RetrievalRecall(BaseMetric):
|
|||
[tensor(9.3000), tensor(48.4000)]
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_gt_label([0, 1]).set_pred_score(
|
|
@ -8,7 +8,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
from mmcls.registry import METRICS
|
||||
from mmpretrain.registry import METRICS
|
||||
|
||||
|
||||
def to_tensor(value):
|
||||
|
@ -91,7 +91,7 @@ class Accuracy(BaseMetric):
|
|||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmcls.evaluation import Accuracy
|
||||
>>> from mmpretrain.evaluation import Accuracy
|
||||
>>> # -------------------- The Basic Usage --------------------
|
||||
>>> y_pred = [0, 2, 1, 3]
|
||||
>>> y_true = [0, 1, 2, 3]
|
||||
|
@ -104,7 +104,7 @@ class Accuracy(BaseMetric):
|
|||
[[tensor([9.9000])], [tensor([51.5000])]]
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_gt_label(0).set_pred_score(torch.rand(10))
|
||||
|
@ -343,7 +343,7 @@ class SingleLabelMetric(BaseMetric):
|
|||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmcls.evaluation import SingleLabelMetric
|
||||
>>> from mmpretrain.evaluation import SingleLabelMetric
|
||||
>>> # -------------------- The Basic Usage --------------------
|
||||
>>> y_pred = [0, 1, 1, 3]
|
||||
>>> y_true = [0, 2, 1, 3]
|
||||
|
@ -358,7 +358,7 @@ class SingleLabelMetric(BaseMetric):
|
|||
(tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
|
||||
|
@ -606,7 +606,7 @@ class ConfusionMatrix(BaseMetric):
|
|||
1. The basic usage.
|
||||
|
||||
>>> import torch
|
||||
>>> from mmcls.evaluation import ConfusionMatrix
|
||||
>>> from mmpretrain.evaluation import ConfusionMatrix
|
||||
>>> y_pred = [0, 1, 1, 3]
|
||||
>>> y_true = [0, 2, 1, 3]
|
||||
>>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)
|
|
@ -3,7 +3,7 @@ from typing import Optional, Sequence
|
|||
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmcls.registry import METRICS
|
||||
from mmpretrain.registry import METRICS
|
||||
from .multi_label import AveragePrecision, MultiLabelMetric
|
||||
|
||||
|
|
@ -8,6 +8,8 @@ from .heads import * # noqa: F401,F403
|
|||
from .losses import * # noqa: F401,F403
|
||||
from .necks import * # noqa: F401,F403
|
||||
from .retrievers import * # noqa: F401,F403
|
||||
from .selfsup import * # noqa: F401,F403
|
||||
from .target_generators import * # noqa: F401,F403
|
||||
from .tta import * # noqa: F401,F403
|
||||
from .utils import * # noqa: F401,F403
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ from mmcv.cnn.bricks.drop import build_dropout
|
|||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import (BEiTAttention, resize_pos_embed,
|
||||
resize_relative_position_bias_table, to_2tuple)
|
||||
from .vision_transformer import TransformerEncoderLayer, VisionTransformer
|
|
@ -10,7 +10,7 @@ from mmcv.cnn.bricks.transformer import AdaptivePadding
|
|||
from mmengine.model import BaseModule
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
from .vision_transformer import TransformerEncoderLayer
|
||||
|
|
@ -7,7 +7,7 @@ from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
|
|||
build_norm_layer)
|
||||
from mmengine.utils import digit_version
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ import torch.utils.checkpoint as cp
|
|||
from mmcv.cnn.bricks import DropPath
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import GRN, build_norm_layer
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
@ -9,7 +9,7 @@ from mmcv.cnn.bricks import DropPath
|
|||
from mmengine.model import BaseModule, Sequential
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import to_ntuple
|
||||
from .resnet import Bottleneck as ResNetBottleneck
|
||||
from .resnext import Bottleneck as ResNeXtBottleneck
|
||||
|
@ -278,8 +278,8 @@ class CSPNet(BaseModule):
|
|||
>>> from functools import partial
|
||||
>>> import torch
|
||||
>>> import torch.nn as nn
|
||||
>>> from mmcls.models import CSPNet
|
||||
>>> from mmcls.models.backbones.resnet import Bottleneck
|
||||
>>> from mmpretrain.models import CSPNet
|
||||
>>> from mmpretrain.models.backbones.resnet import Bottleneck
|
||||
>>>
|
||||
>>> # A simple example to build CSPNet.
|
||||
>>> arch = dict(
|
||||
|
@ -427,7 +427,7 @@ class CSPDarkNet(CSPNet):
|
|||
Default: None.
|
||||
|
||||
Example:
|
||||
>>> from mmcls.models import CSPDarkNet
|
||||
>>> from mmpretrain.models import CSPDarkNet
|
||||
>>> import torch
|
||||
>>> model = CSPDarkNet(depth=53, out_indices=(0, 1, 2, 3, 4))
|
||||
>>> model.eval()
|
||||
|
@ -523,7 +523,7 @@ class CSPResNet(CSPNet):
|
|||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Example:
|
||||
>>> from mmcls.models import CSPResNet
|
||||
>>> from mmpretrain.models import CSPResNet
|
||||
>>> import torch
|
||||
>>> model = CSPResNet(depth=50, out_indices=(0, 1, 2, 3))
|
||||
>>> model.eval()
|
||||
|
@ -645,7 +645,7 @@ class CSPResNeXt(CSPResNet):
|
|||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Example:
|
||||
>>> from mmcls.models import CSPResNeXt
|
||||
>>> from mmpretrain.models import CSPResNeXt
|
||||
>>> import torch
|
||||
>>> model = CSPResNeXt(depth=50, out_indices=(0, 1, 2, 3))
|
||||
>>> model.eval()
|
|
@ -12,8 +12,8 @@ from mmengine.model import BaseModule, ModuleList
|
|||
from mmengine.utils import to_2tuple
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.models.backbones.base_backbone import BaseBackbone
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import ShiftWindowMSA
|
||||
|
||||
|
|
@ -3,7 +3,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .vision_transformer import VisionTransformer
|
||||
|
||||
|
|
@ -10,7 +10,7 @@ from mmengine.model import BaseModule, ModuleList, Sequential
|
|||
from mmengine.utils import deprecated_api_warning
|
||||
from torch import nn
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import LayerScale, MultiheadAttention, resize_pos_embed, to_2tuple
|
||||
from .vision_transformer import VisionTransformer
|
||||
|
|
@ -10,7 +10,7 @@ import torch.utils.checkpoint as cp
|
|||
from mmcv.cnn.bricks import build_activation_layer, build_norm_layer
|
||||
from torch.jit.annotations import List
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
from mmcv.cnn.bricks import DropPath
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import (ChannelMultiheadAttention, PositionEncodingFourier,
|
||||
build_norm_layer)
|
||||
from .base_backbone import BaseBackbone
|
|
@ -8,7 +8,7 @@ from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
|
|||
build_norm_layer)
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import LayerScale
|
||||
from .base_backbone import BaseBackbone
|
||||
from .poolformer import Pooling
|
||||
|
@ -376,7 +376,7 @@ class EfficientFormer(BaseBackbone):
|
|||
Defaults to None.
|
||||
|
||||
Example:
|
||||
>>> from mmcls.models import EfficientFormer
|
||||
>>> from mmpretrain.models import EfficientFormer
|
||||
>>> import torch
|
||||
>>> inputs = torch.rand((1, 3, 224, 224))
|
||||
>>> # build EfficientFormer backbone for classification task
|
|
@ -9,9 +9,9 @@ import torch.utils.checkpoint as cp
|
|||
from mmcv.cnn.bricks import ConvModule, DropPath
|
||||
from mmengine.model import BaseModule, Sequential
|
||||
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
from mmcls.models.utils import InvertedResidual, SELayer, make_divisible
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.models.backbones.base_backbone import BaseBackbone
|
||||
from mmpretrain.models.utils import InvertedResidual, SELayer, make_divisible
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
class EdgeResidual(BaseModule):
|
|
@ -7,10 +7,10 @@ from mmcv.cnn.bricks import ConvModule, DropPath
|
|||
from mmengine.model import Sequential
|
||||
from torch import Tensor
|
||||
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
from mmcls.models.backbones.efficientnet import EdgeResidual as FusedMBConv
|
||||
from mmcls.models.utils import InvertedResidual as MBConv
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import InvertedResidual as MBConv
|
||||
from .base_backbone import BaseBackbone
|
||||
from .efficientnet import EdgeResidual as FusedMBConv
|
||||
|
||||
|
||||
class EnhancedConvModule(ConvModule):
|
|
@ -16,8 +16,8 @@ import torch.nn.functional as F
|
|||
import torch.utils.checkpoint as checkpoint
|
||||
from mmcv.cnn.bricks import DropPath
|
||||
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.models.backbones.base_backbone import BaseBackbone
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import LayerScale
|
||||
|
||||
|
|
@ -4,7 +4,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
|
|||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .resnet import BasicBlock, Bottleneck, ResLayer, get_expansion
|
||||
|
||||
|
||||
|
@ -232,7 +232,7 @@ class HRNet(BaseModule):
|
|||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> from mmcls.models import HRNet
|
||||
>>> from mmpretrain.models import HRNet
|
||||
>>> extra = dict(
|
||||
>>> stage1=dict(
|
||||
>>> num_modules=1,
|
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||
from mmcv.cnn import build_conv_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
|
@ -389,7 +389,7 @@ class InceptionV3(BaseBackbone):
|
|||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> from mmcls.models import build_backbone
|
||||
>>> from mmpretrain.models import build_backbone
|
||||
>>>
|
||||
>>> inputs = torch.rand(2, 3, 299, 299)
|
||||
>>> cfg = dict(type='InceptionV3', num_classes=100)
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
|
@ -7,8 +7,8 @@ from mmcv.cnn import build_activation_layer, fuse_conv_bn
|
|||
from mmcv.cnn.bricks import DropPath
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.models.backbones.base_backbone import BaseBackbone
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import build_norm_layer
|
||||
|
||||
|
|
@ -9,11 +9,10 @@ from mmengine.model import BaseModule
|
|||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
|
||||
from mmcls.models.utils.attention import WindowMSA
|
||||
from mmcls.models.utils.helpers import to_2tuple
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import WindowMSA, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
from .vision_transformer import TransformerEncoderLayer
|
||||
|
||||
|
||||
class MixMIMWindowAttention(WindowMSA):
|
|
@ -6,7 +6,7 @@ from mmcv.cnn import build_norm_layer
|
|||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
@ -5,8 +5,8 @@ from mmcv.cnn import ConvModule
|
|||
from mmengine.model import BaseModule
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.utils import make_divisible
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.models.utils import make_divisible
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
from mmcv.cnn import ConvModule
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import InvertedResidual
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
@ -9,7 +9,7 @@ from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
|
|||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils.se_layer import SELayer
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
@ -302,7 +302,7 @@ class MobileOne(BaseBackbone):
|
|||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
|
||||
Example:
|
||||
>>> from mmcls.models import MobileOne
|
||||
>>> from mmpretrain.models import MobileOne
|
||||
>>> import torch
|
||||
>>> x = torch.rand(1, 3, 224, 224)
|
||||
>>> model = MobileOne("s0", out_indices=(0, 1, 2, 3))
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
from torch import nn
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
from .mobilenet_v2 import InvertedResidual
|
||||
from .vision_transformer import TransformerEncoderLayer
|
|
@ -489,7 +489,7 @@ class MViT(BaseBackbone):
|
|||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmcls.models import build_backbone
|
||||
>>> from mmpretrain.models import build_backbone
|
||||
>>>
|
||||
>>> cfg = dict(type='MViT', arch='tiny', out_scales=[0, 1, 2, 3])
|
||||
>>> model = build_backbone(cfg)
|
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .resnet import ResNet
|
||||
from .resnext import Bottleneck
|
||||
|
||||
|
@ -43,7 +43,7 @@ class RegNet(ResNet):
|
|||
in resblocks to let them behave as identity. Default: True.
|
||||
|
||||
Example:
|
||||
>>> from mmcls.models import RegNet
|
||||
>>> from mmpretrain.models import RegNet
|
||||
>>> import torch
|
||||
>>> self = RegNet(
|
||||
arch=dict(
|
|
@ -7,7 +7,7 @@ from mmcv.cnn.bricks import DropPath
|
|||
from mmengine.model import BaseModule
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
|
@ -8,8 +8,8 @@ from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
|
|||
from mmcv.cnn.bricks.transformer import PatchEmbed as _PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
|
||||
from mmcls.models.utils import SELayer, to_2tuple
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.models.utils import SELayer, to_2tuple
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
def fuse_bn(conv_or_fc, bn):
|
||||
|
@ -88,7 +88,7 @@ class PatchEmbed(_PatchEmbed):
|
|||
|
||||
|
||||
class GlobalPerceptron(SELayer):
|
||||
"""GlobalPerceptron implemented by using ``mmcls.modes.SELayer``.
|
||||
"""GlobalPerceptron implemented by using ``mmpretrain.modes.SELayer``.
|
||||
|
||||
Args:
|
||||
input_channels (int): The number of input (and output) channels
|
|
@ -8,7 +8,7 @@ from mmengine.model import BaseModule, Sequential
|
|||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
from torch import nn
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils.se_layer import SELayer
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
@ -7,7 +7,7 @@ import torch.utils.checkpoint as cp
|
|||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
from mmengine.model import ModuleList, Sequential
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResNet
|
||||
|
||||
|
@ -258,7 +258,7 @@ class Res2Net(ResNet):
|
|||
Defaults to None.
|
||||
|
||||
Example:
|
||||
>>> from mmcls.models import Res2Net
|
||||
>>> from mmpretrain.models import Res2Net
|
||||
>>> import torch
|
||||
>>> model = Res2Net(depth=50,
|
||||
... scales=4,
|
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResLayer, ResNetV1d
|
||||
|
|
@ -9,7 +9,7 @@ from mmengine.model import BaseModule
|
|||
from mmengine.model.weight_init import constant_init
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
eps = 1.0e-5
|
||||
|
@ -438,7 +438,7 @@ class ResNet(BaseBackbone):
|
|||
in resblocks to let them behave as identity. Default: True.
|
||||
|
||||
Example:
|
||||
>>> from mmcls.models import ResNet
|
||||
>>> from mmpretrain.models import ResNet
|
||||
>>> import torch
|
||||
>>> self = ResNet(depth=18)
|
||||
>>> self.eval()
|
|
@ -2,7 +2,7 @@
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .resnet import ResNet
|
||||
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmpretrain.registry import MODELS
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResLayer, ResNet
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue