Rename the package to `mmpretrain`.

pull/1371/head
mzr1996 2023-02-17 11:31:08 +08:00
parent 8352951f3d
commit 0979e78573
324 changed files with 721 additions and 691 deletions

View File

@@ -17,10 +17,10 @@ from modelindex.load_model_index import load
 from rich.console import Console
 from rich.table import Table
-from mmcls.apis import init_model
-from mmcls.datasets import CIFAR10, CIFAR100, ImageNet
-from mmcls.utils import register_all_modules
-from mmcls.visualization import ClsVisualizer
+from mmpretrain.apis import init_model
+from mmpretrain.datasets import CIFAR10, CIFAR100, ImageNet
+from mmpretrain.utils import register_all_modules
+from mmpretrain.visualization import ClsVisualizer
 console = Console()
 MMCLS_ROOT = Path(__file__).absolute().parents[2]
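For downstream code, the rename in this commit amounts to swapping the import root from `mmcls` to `mmpretrain`. A minimal sketch using the APIs touched in the hunk above (the config path is illustrative, taken from the docstrings later in this diff):

# Before this commit:
#   from mmcls.apis import init_model
#   from mmcls.utils import register_all_modules
# After this commit:
from mmpretrain.apis import init_model
from mmpretrain.utils import register_all_modules

register_all_modules()  # populate the registries under the new default scope
model = init_model('configs/resnet/resnet50_8xb32_in1k.py', device='cpu')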

View File

@@ -18,9 +18,9 @@ from modelindex.load_model_index import load
 from rich.console import Console
 from rich.table import Table
-from mmcls.datasets.builder import build_dataloader
-from mmcls.datasets.pipelines import Compose
-from mmcls.models.builder import build_classifier
+from mmpretrain.datasets.builder import build_dataloader
+from mmpretrain.datasets.pipelines import Compose
+from mmpretrain.models.builder import build_classifier
 console = Console()
 MMCLS_ROOT = Path(__file__).absolute().parents[2]

View File

@@ -47,7 +47,7 @@ def ckpt_to_state_dict(checkpoint, key=None):
 if key is not None:
 state_dict = checkpoint[key]
 elif 'state_dict' in checkpoint:
-# try mmcls style
+# try mmpretrain style
 state_dict = checkpoint['state_dict']
 elif 'model' in checkpoint:
 state_dict = checkpoint['model']
@@ -149,7 +149,7 @@ def main():
 if args.path.suffix in ['.json', '.py', '.yml']:
 from mmengine.runner import get_state_dict
-from mmcls.apis import init_model
+from mmpretrain.apis import init_model
 model = init_model(args.path, device='cpu')
 state_dict = get_state_dict(model)
 else:

View File

@@ -55,7 +55,7 @@ def state_dict_from_cfg_or_ckpt(path, state_key=None):
 if path.suffix in ['.json', '.py', '.yml']:
 from mmengine.runner import get_state_dict
-from mmcls.apis import init_model
+from mmpretrain.apis import init_model
 model = init_model(path, device='cpu')
 model.init_weights()
 return get_state_dict(model)

View File

@@ -84,8 +84,8 @@ def get_flops(config_path):
 from mmengine.dataset import Compose
 from mmengine.registry import DefaultScope
-import mmcls.datasets # noqa: F401
-from mmcls.apis import init_model
+import mmpretrain.datasets # noqa: F401
+from mmpretrain.apis import init_model
 cfg = Config.fromfile(config_path)
@@ -98,7 +98,7 @@ def get_flops(config_path):
 # The image shape of CIFAR is (32, 32, 3)
 test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
-with DefaultScope.overwrite_default_scope('mmcls'):
+with DefaultScope.overwrite_default_scope('mmpretrain'):
 data = Compose(test_dataset.pipeline)({
 'img':
 np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)

.gitignore (vendored)
View File

@@ -127,7 +127,7 @@ venv.bak/
 /work_dirs
 /projects/*/work_dirs
 /projects/*/data
-/mmcls/.mim
+/mmpretrain/.mim
 .DS_Store
 # Pytorch

View File

@@ -1,4 +1,4 @@
 include requirements/*.txt
-include mmcls/.mim/model-index.yml
-recursive-include mmcls/.mim/configs *.py *.yml
-recursive-include mmcls/.mim/tools *.py *.sh
+include mmpretrain/.mim/model-index.yml
+recursive-include mmpretrain/.mim/configs *.py *.yml
+recursive-include mmpretrain/.mim/tools *.py *.sh

View File

@@ -1,5 +1,5 @@
-# defaults to use registries in mmcls
-default_scope = 'mmcls'
+# defaults to use registries in mmpretrain
+default_scope = 'mmpretrain'
 # configure default hooks
 default_hooks = dict(

View File

@@ -1,4 +1,4 @@
-from mmcls.models import build_classifier
+from mmpretrain.models import build_classifier
 model = dict(
 type='ImageClassifier',

View File

@@ -2,11 +2,11 @@ _base_ = ['../_base_/datasets/voc_bs16.py', '../_base_/default_runtime.py']
 # Pre-trained Checkpoint Path
 checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth' # noqa
-# If you want to use the pre-trained weight of ResNet101-CutMix from
-# the originary repo(https://github.com/Kevinz-code/CSRA). Script of
-# 'tools/convert_models/torchvision_to_mmcls.py' can help you convert weight
-# into mmcls format. The mAP result would hit 95.5 by using the weight.
+# If you want to use the pre-trained weight of ResNet101-CutMix from the
+# originary repo(https://github.com/Kevinz-code/CSRA). Script of
+# 'tools/model_converters/torchvision_to_mmpretrain.py' can help you convert
+# weight into mmpretrain format. The mAP result would hit 95.5 by using the
+# weight.
 # checkpoint = 'PATH/TO/PRE-TRAINED_WEIGHT'
 # model settings
 model = dict(

View File

@@ -51,7 +51,7 @@ val_cfg = dict()
 test_cfg = dict()
 # runtime settings
-default_scope = 'mmcls'
+default_scope = 'mmpretrain'
 default_hooks = dict(
 # record the time of every iteration.

View File

@@ -4,7 +4,7 @@ from argparse import ArgumentParser
 from mmengine.fileio import dump
 from rich import print_json
-from mmcls.apis import ImageClassificationInferencer
+from mmpretrain.apis import ImageClassificationInferencer
 def main():
@@ -30,7 +30,7 @@ def main():
 raise ValueError(
 f'Unavailable model "{args.model}", you can specify find a model '
 'name or a config file or find a model name from '
-'https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html#all-checkpoints' # noqa: E501
+'https://mmpretrain.readthedocs.io/en/1.x/modelzoo_statistics.html#all-checkpoints' # noqa: E501
 )
 result = inferencer(args.img, show=args.show, show_dir=args.show_dir)[0]
 # show the results

View File

@@ -22,12 +22,12 @@ sys.path.insert(0, os.path.abspath('../../'))
 # -- Project information -----------------------------------------------------
-project = 'MMClassification'
+project = 'MMPretrain'
 copyright = '2020, OpenMMLab'
-author = 'MMClassification Authors'
+author = 'MMPretrain Authors'
 # The full version, including alpha/beta/rc tags
-version_file = '../../mmcls/version.py'
+version_file = '../../mmpretrain/version.py'
 def get_version():
@@ -92,25 +92,25 @@ html_theme_options = {
 'menu': [
 {
 'name': 'GitHub',
-'url': 'https://github.com/open-mmlab/mmclassification'
+'url': 'https://github.com/open-mmlab/mmpretrain'
 },
 {
 'name': 'Colab Tutorials',
 'children': [
 {'name': 'Train and inference with shell commands',
-'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_tools.ipynb'},
+'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_tools.ipynb'},
 {'name': 'Train and inference with Python APIs',
-'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_python.ipynb'},
+'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_python.ipynb'},
 ]
 },
 {
 'name': 'Version',
 'children': [
-{'name': 'MMClassification 0.x',
-'url': 'https://mmclassification.readthedocs.io/en/latest/',
+{'name': 'MMPretrain 0.x',
+'url': 'https://mmpretrain.readthedocs.io/en/latest/',
 'description': 'master branch'},
-{'name': 'MMClassification 1.x',
-'url': 'https://mmclassification.readthedocs.io/en/dev-1.x/',
+{'name': 'MMPretrain 1.x',
+'url': 'https://mmpretrain.readthedocs.io/en/dev-1.x/',
 'description': '1.x branch'},
 ],
 }
@@ -138,7 +138,7 @@ html_js_files = [
 # -- Options for HTMLHelp output ---------------------------------------------
 # Output file base name for HTML help builder.
-htmlhelp_basename = 'mmclsdoc'
+htmlhelp_basename = 'mmpretraindoc'
 # -- Options for LaTeX output ------------------------------------------------
@@ -160,16 +160,14 @@ latex_elements = {
 # (source start file, target name, title,
 # author, documentclass [howto, manual, or own class]).
 latex_documents = [
-(root_doc, 'mmcls.tex', 'MMClassification Documentation', author,
-'manual'),
+(root_doc, 'mmpretrain.tex', 'MMPretrain Documentation', author, 'manual'),
 ]
 # -- Options for manual page output ------------------------------------------
 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
-]
+man_pages = [(root_doc, 'mmpretrain', 'MMPretrain Documentation', [author], 1)]
 # -- Options for Texinfo output ----------------------------------------------
@@ -177,7 +175,7 @@ man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
 # (source start file, target name, title, author,
 # dir menu entry, description, category)
 texinfo_documents = [
-(root_doc, 'mmcls', 'MMClassification Documentation', author, 'mmcls',
+(root_doc, 'mmpretrain', 'MMPretrain Documentation', author, 'mmpretrain',
 'OpenMMLab image classification toolbox and benchmark.', 'Miscellaneous'),
 ]

View File

@@ -9,7 +9,7 @@ from tabulate import tabulate
 MMCLS_ROOT = Path(__file__).absolute().parents[2]
 PAPERS_ROOT = Path('papers') # Path to save generated paper pages.
-GITHUB_PREFIX = 'https://github.com/open-mmlab/mmclassification/blob/1.x/'
+GITHUB_PREFIX = 'https://github.com/open-mmlab/mmpretrain/blob/1.x/'
 MODELZOO_TEMPLATE = """
 # Model Zoo Summary

View File

@@ -22,12 +22,12 @@ sys.path.insert(0, os.path.abspath('../..'))
 # -- Project information -----------------------------------------------------
-project = 'MMClassification'
+project = 'MMPretrain'
 copyright = '2020, OpenMMLab'
-author = 'MMClassification Authors'
+author = 'MMPretrain Authors'
 # The full version, including alpha/beta/rc tags
-version_file = '../../mmcls/version.py'
+version_file = '../../mmpretrain/version.py'
 def get_version():
@@ -92,25 +92,25 @@ html_theme_options = {
 'menu': [
 {
 'name': 'GitHub',
-'url': 'https://github.com/open-mmlab/mmclassification'
+'url': 'https://github.com/open-mmlab/mmpretrain'
 },
 {
 'name': 'Colab 教程',
 'children': [
 {'name': '用命令行工具训练和推理',
-'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_tools.ipynb'},
+'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_tools.ipynb'},
 {'name': '用 Python API 训练和推理',
-'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_python.ipynb'},
+'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_python.ipynb'},
 ]
 },
 {
 'name': 'Version',
 'children': [
-{'name': 'MMClassification 0.x',
-'url': 'https://mmclassification.readthedocs.io/zh_CN/latest/',
+{'name': 'MMPretrain 0.x',
+'url': 'https://mmpretrain.readthedocs.io/zh_CN/latest/',
 'description': 'master branch'},
-{'name': 'MMClassification 1.x',
-'url': 'https://mmclassification.readthedocs.io/zh_CN/dev-1.x/',
+{'name': 'MMPretrain 1.x',
+'url': 'https://mmpretrain.readthedocs.io/zh_CN/dev-1.x/',
 'description': '1.x branch'},
 ],
 }
@@ -138,7 +138,7 @@ html_js_files = [
 # -- Options for HTMLHelp output ---------------------------------------------
 # Output file base name for HTML help builder.
-htmlhelp_basename = 'mmclsdoc'
+htmlhelp_basename = 'mmpretraindoc'
 # -- Options for LaTeX output ------------------------------------------------
@@ -164,16 +164,14 @@ latex_elements = {
 # (source start file, target name, title,
 # author, documentclass [howto, manual, or own class]).
 latex_documents = [
-(root_doc, 'mmcls.tex', 'MMClassification Documentation', author,
-'manual'),
+(root_doc, 'mmpretrain.tex', 'MMPretrain Documentation', author, 'manual'),
 ]
 # -- Options for manual page output ------------------------------------------
 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
-]
+man_pages = [(root_doc, 'mmpretrain', 'MMPretrain Documentation', [author], 1)]
 # -- Options for Texinfo output ----------------------------------------------
@@ -181,7 +179,7 @@ man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
 # (source start file, target name, title, author,
 # dir menu entry, description, category)
 texinfo_documents = [
-(root_doc, 'mmcls', 'MMClassification Documentation', author, 'mmcls',
+(root_doc, 'mmpretrain', 'MMPretrain Documentation', author, 'mmpretrain',
 'OpenMMLab image classification toolbox and benchmark.', 'Miscellaneous'),
 ]

View File

@@ -9,7 +9,7 @@ from tabulate import tabulate
 MMCLS_ROOT = Path(__file__).absolute().parents[2]
 PAPERS_ROOT = Path('papers') # Path to save generated paper pages.
-GITHUB_PREFIX = 'https://github.com/open-mmlab/mmclassification/blob/1.x/'
+GITHUB_PREFIX = 'https://github.com/open-mmlab/mmpretrain/blob/1.x/'
 MODELZOO_TEMPLATE = """
 # 模型库统计

View File

@@ -1,39 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import datetime
-import warnings
-from mmengine import DefaultScope
-def register_all_modules(init_default_scope: bool = True) -> None:
-"""Register all modules in mmcls into the registries.
-Args:
-init_default_scope (bool): Whether initialize the mmcls default scope.
-If True, the global default scope will be set to `mmcls`, and all
-registries will build modules from mmcls's registry node. To
-understand more about the registry, please refer to
-https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
-Defaults to True.
-""" # noqa
-import mmcls.datasets # noqa: F401,F403
-import mmcls.engine # noqa: F401,F403
-import mmcls.evaluation # noqa: F401,F403
-import mmcls.models # noqa: F401,F403
-import mmcls.structures # noqa: F401,F403
-import mmcls.visualization # noqa: F401,F403
-if not init_default_scope:
-return
-current_scope = DefaultScope.get_current_instance()
-if current_scope is None:
-DefaultScope.get_instance('mmcls', scope_name='mmcls')
-elif current_scope.scope_name != 'mmcls':
-warnings.warn(f'The current default scope "{current_scope.scope_name}"'
-' is not "mmcls", `register_all_modules` will force the '
-'current default scope to be "mmcls". If this is not '
-'expected, please set `init_default_scope=False`.')
-# avoid name conflict
-new_instance_name = f'mmcls-{datetime.datetime.now()}'
-DefaultScope.get_instance(new_instance_name, scope_name='mmcls')
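The file above is removed under the old package root; a renamed counterpart is assumed to live under `mmpretrain/utils/`, since other hunks in this commit import `from mmpretrain.utils import register_all_modules` and set `default_scope = 'mmpretrain'`. A sketch of the equivalent scope-setup logic under the new name, assuming the structure of the deleted function is unchanged:

import datetime
import warnings

from mmengine import DefaultScope


def register_all_modules(init_default_scope: bool = True) -> None:
    """Register all mmpretrain modules into the registries (sketch)."""
    # Importing the sub-packages triggers their register_module() decorators.
    import mmpretrain.datasets  # noqa: F401,F403
    import mmpretrain.engine  # noqa: F401,F403
    import mmpretrain.evaluation  # noqa: F401,F403
    import mmpretrain.models  # noqa: F401,F403
    import mmpretrain.structures  # noqa: F401,F403
    import mmpretrain.visualization  # noqa: F401,F403

    if not init_default_scope:
        return

    current_scope = DefaultScope.get_current_instance()
    if current_scope is None:
        DefaultScope.get_instance('mmpretrain', scope_name='mmpretrain')
    elif current_scope.scope_name != 'mmpretrain':
        warnings.warn(f'The current default scope "{current_scope.scope_name}"'
                      ' is not "mmpretrain", `register_all_modules` will '
                      'force the current default scope to be "mmpretrain".')
        # Avoid a name conflict when re-initializing the scope.
        new_instance_name = f'mmpretrain-{datetime.datetime.now()}'
        DefaultScope.get_instance(new_instance_name, scope_name='mmpretrain')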

View File

@@ -11,8 +11,8 @@ from mmengine.infer import BaseInferencer
 from mmengine.model import BaseModel
 from mmengine.runner import load_checkpoint
-from mmcls.registry import TRANSFORMS
-from mmcls.structures import ClsDataSample
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import ClsDataSample
 from .model import get_model, init_model, list_models
 ModelType = Union[BaseModel, str, Config]
@@ -63,7 +63,7 @@ class ImageClassificationInferencer(BaseInferencer):
 Example:
 1. Use a pre-trained model in MMClassification to inference an image.
->>> from mmcls import ImageClassificationInferencer
+>>> from mmpretrain import ImageClassificationInferencer
 >>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
 >>> inferencer('demo/demo.JPEG')
 [{'pred_score': array([...]),
@@ -74,7 +74,7 @@ class ImageClassificationInferencer(BaseInferencer):
 2. Use a config file and checkpoint to inference multiple images on GPU,
 and save the visualization results in a folder.
->>> from mmcls import ImageClassificationInferencer
+>>> from mmpretrain import ImageClassificationInferencer
 >>> inferencer = ImageClassificationInferencer(
 model='configs/resnet/resnet50_8xb32_in1k.py',
 weights='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
@@ -107,7 +107,7 @@ class ImageClassificationInferencer(BaseInferencer):
 else:
 raise TypeError(
 'The `model` can be a name of model and you can use '
-'`mmcls.list_models` to get an available name. It can '
+'`mmpretrain.list_models` to get an available name. It can '
 'also be a Config object or a path to the config file.')
 model.eval()
@@ -185,7 +185,7 @@ class ImageClassificationInferencer(BaseInferencer):
 return None
 if self.visualizer is None:
-from mmcls.visualization import ClsVisualizer
+from mmpretrain.visualization import ClsVisualizer
 self.visualizer = ClsVisualizer()
 if self.classes is not None:
 self.visualizer._dataset_meta = dict(classes=self.classes)

View File

@@ -15,7 +15,7 @@ from modelindex.models.Model import Model
 class ModelHub:
 """A hub to host the meta information of all pre-defined models."""
 _models_dict = {}
-__mmcls_registered = False
+__mmpretrain_registered = False
 @classmethod
 def register_model_index(cls,
@@ -52,7 +52,7 @@ class ModelHub:
 Returns:
 modelindex.models.Model: The metainfo of the specified model.
 """
-cls._register_mmcls_models()
+cls._register_mmpretrain_models()
 # lazy load config
 metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
 if metainfo is None:
@@ -75,15 +75,15 @@ class ModelHub:
 return config_path
 @classmethod
-def _register_mmcls_models(cls):
-# register models in mmcls
-if not cls.__mmcls_registered:
+def _register_mmpretrain_models(cls):
+# register models in mmpretrain
+if not cls.__mmpretrain_registered:
 from mmengine.utils import get_installed_path
-mmcls_root = Path(get_installed_path('mmcls'))
-model_index_path = mmcls_root / '.mim' / 'model-index.yml'
+mmpretrain_root = Path(get_installed_path('mmpretrain'))
+model_index_path = mmpretrain_root / '.mim' / 'model-index.yml'
 ModelHub.register_model_index(
-model_index_path, config_prefix=mmcls_root / '.mim')
-cls.__mmcls_registered = True
+model_index_path, config_prefix=mmpretrain_root / '.mim')
+cls.__mmpretrain_registered = True
 @classmethod
 def has(cls, model_name):
@@ -118,7 +118,7 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
 config.model.setdefault('data_preprocessor',
 config.get('data_preprocessor', None))
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 model = MODELS.build(config.model)
 if checkpoint is not None:
@@ -130,13 +130,13 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
 # Don't set CLASSES if the model is headless.
 pass
 elif 'dataset_meta' in checkpoint.get('meta', {}):
-# mmcls 1.x
+# mmpretrain 1.x
 model.CLASSES = checkpoint['meta']['dataset_meta'].get('classes')
 elif 'CLASSES' in checkpoint.get('meta', {}):
-# mmcls < 1.x
+# mmpretrain < 1.x or mmselfsup < 1.x
 model.CLASSES = checkpoint['meta']['CLASSES']
 else:
-from mmcls.datasets.categories import IMAGENET_CATEGORIES
+from mmpretrain.datasets.categories import IMAGENET_CATEGORIES
 warnings.simplefilter('once')
 warnings.warn('Class names are not saved in the checkpoint\'s '
 'meta data, use imagenet by default.')
@@ -165,7 +165,7 @@ def get_model(model_name, pretrained=False, device=None, **kwargs):
 Get a ResNet-50 model and extract images feature:
 >>> import torch
->>> from mmcls import get_model
+>>> from mmpretrain import get_model
 >>> inputs = torch.rand(16, 3, 224, 224)
 >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
 >>> feats = model.extract_feat(inputs)
@@ -178,7 +178,7 @@ def get_model(model_name, pretrained=False, device=None, **kwargs):
 Get Swin-Transformer model with pre-trained weights and inference:
->>> from mmcls import get_model, inference_model
+>>> from mmpretrain import get_model, inference_model
 >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
 >>> result = inference_model(model, 'demo/demo.JPEG')
 >>> print(result['pred_class'])
@@ -201,7 +201,7 @@ def get_model(model_name, pretrained=False, device=None, **kwargs):
 def list_models(pattern=None) -> List[str]:
-"""List all models available in MMClassification.
+"""List all models available in MMPretrain.
 Args:
 pattern (str | None): A wildcard pattern to match model names.
@@ -212,12 +212,12 @@ def list_models(pattern=None) -> List[str]:
 Examples:
 List all models:
->>> from mmcls import list_models
+>>> from mmpretrain import list_models
 >>> print(list_models())
 List ResNet-50 models on ImageNet-1k dataset:
->>> from mmcls import list_models
+>>> from mmpretrain import list_models
 >>> print(list_models('resnet*in1k'))
 ['resnet50_8xb32_in1k',
 'resnet50_8xb32-fp16_in1k',
@@ -225,7 +225,7 @@ def list_models(pattern=None) -> List[str]:
 'resnet50_8xb256-rsb-a2-300e_in1k',
 'resnet50_8xb256-rsb-a3-100e_in1k']
 """
-ModelHub._register_mmcls_models()
+ModelHub._register_mmpretrain_models()
 if pattern is None:
 return sorted(list(ModelHub._models_dict.keys()))
 # Always match keys with any postfix.
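Taken together, the renamed high-level APIs in this file can be exercised roughly as below; the model name, wildcard pattern, and demo image are the ones already quoted in the docstrings above, not new assumptions:

from mmpretrain import ImageClassificationInferencer, get_model, list_models

# Wildcard search over the model index shipped in mmpretrain/.mim.
print(list_models('resnet*in1k'))

# Build a model, optionally with pre-trained weights.
model = get_model('resnet50_8xb32_in1k', pretrained=True)

# Or run end-to-end inference through the inferencer.
inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
result = inferencer('demo/demo.JPEG')[0]
print(result)  # e.g. {'pred_score': array([...]), ...}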

View File

@@ -7,7 +7,7 @@ import mmengine
 import numpy as np
 from mmengine.dataset import BaseDataset as _BaseDataset
-from mmcls.registry import DATASETS, TRANSFORMS
+from mmpretrain.registry import DATASETS, TRANSFORMS
 def expanduser(path):

View File

@@ -1,12 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcls.registry import DATASETS
+from mmpretrain.registry import DATASETS
 def build_dataset(cfg):
 """Build dataset.
 Examples:
->>> from mmcls.datasets import build_dataset
+>>> from mmpretrain.datasets import build_dataset
 >>> mnist_train = build_dataset(
 ... dict(type='MNIST', data_prefix='data/mnist/', test_mode=False))
 >>> print(mnist_train)

View File

@@ -7,7 +7,7 @@ import numpy as np
 from mmengine.fileio import (LocalBackend, exists, get, get_file_backend,
 join_path)
-from mmcls.registry import DATASETS
+from mmpretrain.registry import DATASETS
 from .base_dataset import BaseDataset
 from .categories import CIFAR10_CATEGORIES, CIFAR100_CATEGORIES
 from .utils import check_md5, download_and_extract_archive

View File

@@ -3,7 +3,7 @@ from typing import List
 from mmengine import get_file_backend, list_from_file
-from mmcls.registry import DATASETS
+from mmpretrain.registry import DATASETS
 from .base_dataset import BaseDataset
 from .categories import CUB_CATEGORIES
@@ -51,7 +51,7 @@ class CUB(BaseDataset):
 Examples:
->>> from mmcls.datasets import CUB
+>>> from mmpretrain.datasets import CUB
 >>> cub_train_cfg = dict(data_root='data/CUB_200_2011', test_mode=True)
 >>> cub_train = CUB(**cub_train_cfg)
 >>> cub_train

View File

@@ -5,7 +5,7 @@ from mmengine.fileio import (BaseStorageBackend, get_file_backend,
 list_from_file)
 from mmengine.logging import MMLogger
-from mmcls.registry import DATASETS
+from mmpretrain.registry import DATASETS
 from .base_dataset import BaseDataset

View File

@@ -4,7 +4,7 @@ import copy
 import numpy as np
 from mmengine.dataset import BaseDataset, force_full_init
-from mmcls.registry import DATASETS
+from mmpretrain.registry import DATASETS
 @DATASETS.register_module()

View File

@@ -3,7 +3,7 @@ from typing import Optional, Union
 from mmengine.logging import MMLogger
-from mmcls.registry import DATASETS
+from mmpretrain.registry import DATASETS
 from .categories import IMAGENET_CATEGORIES
 from .custom import CustomDataset

View File

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmengine import get_file_backend, list_from_file
-from mmcls.datasets.base_dataset import BaseDataset
-from mmcls.registry import DATASETS
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
 @DATASETS.register_module()
@@ -35,7 +35,7 @@ class InShop(BaseDataset):
 **kwargs: Other keyword arguments in :class:`BaseDataset`.
 Examples:
->>> from mmcls.datasets import InShop
+>>> from mmpretrain.datasets import InShop
 >>>
 >>> # build train InShop dataset
 >>> inshop_train_cfg = dict(data_root='data/inshop', split='train')

View File

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path
-from mmcls.registry import DATASETS
+from mmpretrain.registry import DATASETS
 from .base_dataset import BaseDataset
 from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
 from .utils import (download_and_extract_archive, open_maybe_compressed_file,

View File

@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import List
-from mmcls.registry import DATASETS
+from mmpretrain.registry import DATASETS
 from .base_dataset import BaseDataset

View File

@@ -60,7 +60,7 @@ class MultiTaskDataset:
 Assume we put our dataset in the ``data/mydataset`` folder in the
 repository and organize it as the below format: ::
-mmclassification/
+mmpretrain/
 data
 mydataset
 annotation
@@ -81,7 +81,7 @@ class MultiTaskDataset:
 .. code:: python
->>> from mmcls.datasets import build_dataset
+>>> from mmpretrain.datasets import build_dataset
 >>> train_cfg = dict(
 ... type="MultiTaskDataset",
 ... ann_file="annotation/train.json",
@@ -94,7 +94,7 @@ class MultiTaskDataset:
 Or we can put all files in the same folder: ::
-mmclassification/
+mmpretrain/
 data
 mydataset
 train.json
@@ -109,7 +109,7 @@ class MultiTaskDataset:
 .. code:: python
->>> from mmcls.datasets import build_dataset
+>>> from mmpretrain.datasets import build_dataset
 >>> train_cfg = dict(
 ... type="MultiTaskDataset",
 ... ann_file="train.json",
@@ -133,8 +133,8 @@ class MultiTaskDataset:
 ``data_root`` for the ``"img_path"`` field in the annotation file.
 Defaults to None.
 pipeline (Sequence[dict]): A list of dict, where each element
-represents a operation defined in :mod:`mmcls.datasets.pipelines`.
-Defaults to an empty tuple.
+represents a operation defined in
+:mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple.
 test_mode (bool): in train mode or test mode. Defaults to False.
 file_client_args (dict, optional): Arguments to instantiate a
 FileClient. See :class:`mmengine.fileio.FileClient` for details.

View File

@@ -5,7 +5,7 @@ import torch
 from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
 from torch.utils.data import Sampler
-from mmcls.registry import DATA_SAMPLERS
+from mmpretrain.registry import DATA_SAMPLERS
 @DATA_SAMPLERS.register_module()

View File

@@ -11,7 +11,7 @@ from mmcv.transforms import BaseTransform, Compose, RandomChoice
 from mmcv.transforms.utils import cache_randomness
 from mmengine.utils import is_list_of, is_seq_of
-from mmcls.registry import TRANSFORMS
+from mmpretrain.registry import TRANSFORMS
 def merge_hparams(policy: dict, hparams: dict) -> dict:
@@ -125,7 +125,7 @@ class RandAugment(BaseTransform):
 time, and magnitude_level of every policy is 6 (total is 10 by default)
 >>> import numpy as np
->>> from mmcls.datasets import RandAugment
+>>> from mmpretrain.datasets import RandAugment
 >>> transform = RandAugment(
 ... policies='timm_increasing',
 ... num_policies=2,

View File

@@ -9,8 +9,8 @@ from mmcv.transforms import BaseTransform
 from mmengine.utils import is_str
 from PIL import Image
-from mmcls.registry import TRANSFORMS
-from mmcls.structures import ClsDataSample, MultiTaskDataSample
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
 def to_tensor(data):
@@ -53,8 +53,8 @@ class PackClsInputs(BaseTransform):
 **Added Keys:**
 - inputs (:obj:`torch.Tensor`): The forward data of models.
-- data_samples (:obj:`~mmcls.structures.ClsDataSample`): The annotation
-info of the sample.
+- data_samples (:obj:`~mmpretrain.structures.ClsDataSample`): The
+annotation info of the sample.
 Args:
 meta_keys (Sequence[str]): The meta keys to be saved in the

View File

@@ -11,7 +11,7 @@ import numpy as np
 from mmcv.transforms import BaseTransform
 from mmcv.transforms.utils import cache_randomness
-from mmcls.registry import TRANSFORMS
+from mmpretrain.registry import TRANSFORMS
 try:
 import albumentations
@@ -1008,7 +1008,7 @@ class Lighting(BaseTransform):
 return repr_str
-# 'Albu' is used in previous versions of mmcls, here is for compatibility
+# 'Albu' is used in previous versions of mmpretrain, here is for compatibility
 # users can use both 'Albumentations' and 'Albu'.
 @TRANSFORMS.register_module(['Albumentations', 'Albu'])
 class Albumentations(BaseTransform):
@@ -1055,13 +1055,13 @@ class Albumentations(BaseTransform):
 Args:
 transforms (List[Dict]): List of albumentations transform configs.
-keymap (Optional[Dict]): Mapping of mmcls to albumentations fields,
-in format {'input key':'albumentation-style key'}. Defaults to
-None.
+keymap (Optional[Dict]): Mapping of mmpretrain to albumentations
+fields, in format {'input key':'albumentation-style key'}.
+Defaults to None.
 Example:
 >>> import mmcv
->>> from mmcls.datasets import Albumentations
+>>> from mmpretrain.datasets import Albumentations
 >>> transforms = [
 ... dict(
 ... type='ShiftScaleRotate',

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Union
 from mmengine import get_file_backend, list_from_file
-from mmcls.registry import DATASETS
+from mmpretrain.registry import DATASETS
 from .base_dataset import expanduser
 from .categories import VOC2007_CATEGORIES
 from .multi_label import MultiLabelDataset

View File

@@ -2,7 +2,7 @@
 from mmengine.hooks import Hook
 from mmengine.utils import is_seq_of
-from mmcls.registry import HOOKS
+from mmpretrain.registry import HOOKS
 @HOOKS.register_module()

View File

@@ -8,7 +8,7 @@ from mmengine.hooks import EMAHook as BaseEMAHook
 from mmengine.logging import MMLogger
 from mmengine.runner import Runner
-from mmcls.registry import HOOKS
+from mmpretrain.registry import HOOKS
 @HOOKS.register_module()

View File

@@ -3,8 +3,8 @@ import numpy as np
 from mmengine.hooks import Hook
 from mmengine.model import is_model_wrapper
-from mmcls.models.heads import ArcFaceClsHead
-from mmcls.registry import HOOKS
+from mmpretrain.models.heads import ArcFaceClsHead
+from mmpretrain.registry import HOOKS
 @HOOKS.register_module()

View File

@@ -20,7 +20,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.instancenorm import _InstanceNorm
 from torch.utils.data import DataLoader
-from mmcls.registry import HOOKS
+from mmpretrain.registry import HOOKS
 DATA_BATCH = Optional[Sequence[dict]]

View File

@@ -4,8 +4,8 @@ import warnings
 from mmengine.hooks import Hook
 from mmengine.model import is_model_wrapper
-from mmcls.models import BaseRetriever
-from mmcls.registry import HOOKS
+from mmpretrain.models import BaseRetriever
+from mmpretrain.registry import HOOKS
 @HOOKS.register_module()
@@ -27,5 +27,6 @@ class PrepareProtoBeforeValLoopHook(Hook):
 model.prepare_prototype()
 else:
 warnings.warn(
-'Only the `mmcls.models.retrievers.BaseRetriever` can execute '
-f'`PrepareRetrieverPrototypeHook`, but got `{type(model)}`')
+'Only the `mmpretrain.models.retrievers.BaseRetriever` '
+'can execute `PrepareRetrieverPrototypeHook`, but got '
+f'`{type(model)}`')

View File

@@ -6,8 +6,8 @@ from mmcv.transforms import Compose
 from mmengine.hooks import Hook
 from mmengine.model import is_model_wrapper
-from mmcls.models.utils import RandomBatchAugment
-from mmcls.registry import HOOKS, MODEL_WRAPPERS, MODELS
+from mmpretrain.models.utils import RandomBatchAugment
+from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS
 @HOOKS.register_module()
@@ -25,9 +25,9 @@ class SwitchRecipeHook(Hook):
 train dataset. If not specified, keep the original settings.
 - ``batch_augments`` (dict | None, optional): The new batch
 augmentations of during training. See :mod:`Batch Augmentations
-<mmcls.models.utils.batch_augments>` for more details. If None,
-disable batch augmentations. If not specified, keep the original
-settings.
+<mmpretrain.models.utils.batch_augments>` for more details.
+If None, disable batch augmentations. If not specified, keep the
+original settings.
 - ``loss`` (dict, optional): The new loss module config. If not
 specified, keep the original settings.

View File

@@ -8,8 +8,8 @@ from mmengine.hooks import Hook
 from mmengine.runner import EpochBasedTrainLoop, Runner
 from mmengine.visualization import Visualizer
-from mmcls.registry import HOOKS
-from mmcls.structures import ClsDataSample
+from mmpretrain.registry import HOOKS
+from mmpretrain.structures import ClsDataSample
 @HOOKS.register_module()
@@ -30,7 +30,7 @@ class VisualizationHook(Hook):
 in the testing process. If None, handle with the backends of the
 visualizer. Defaults to None.
 **kwargs: other keyword arguments of
-:meth:`mmcls.visualization.ClsVisualizer.add_datasample`.
+:meth:`mmpretrain.visualization.ClsVisualizer.add_datasample`.
 """
def __init__(self, def __init__(self,

View File

@@ -19,7 +19,7 @@ import torch
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
-from mmcls.registry import OPTIMIZERS
+from mmpretrain.registry import OPTIMIZERS
 @OPTIMIZERS.register_module()

View File

@@ -61,7 +61,7 @@ import math
 import torch
 from torch.optim import Optimizer
-from mmcls.registry import OPTIMIZERS
+from mmpretrain.registry import OPTIMIZERS
 @OPTIMIZERS.register_module()

View File

@@ -1,2 +1,3 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .functional import * # noqa: F401,F403
 from .metrics import * # noqa: F401,F403

View File

@@ -7,7 +7,7 @@ from mmengine.evaluator import BaseMetric
 from mmengine.logging import MMLogger
 from mmengine.structures import LabelData
-from mmcls.registry import METRICS
+from mmpretrain.registry import METRICS
 from .single_label import _precision_recall_f1_support, to_tensor
@@ -77,7 +77,7 @@ class MultiLabelMetric(BaseMetric):
 Examples:
 >>> import torch
->>> from mmcls.evaluation import MultiLabelMetric
+>>> from mmpretrain.evaluation import MultiLabelMetric
 >>> # ------ The Basic Usage for category indices labels -------
 >>> y_pred = [[0], [1], [0, 1], [3]]
 >>> y_true = [[0, 3], [0, 2], [1], [3]]
@@ -114,7 +114,7 @@ class MultiLabelMetric(BaseMetric):
 (tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8))
 >>>
 >>> # ------------------- Use with Evalutor -------------------
->>> from mmcls.structures import ClsDataSample
+>>> from mmpretrain.structures import ClsDataSample
 >>> from mmengine.evaluator import Evaluator
 >>> data_sampels = [
 ... ClsDataSample().set_pred_score(pred).set_gt_score(gt)
@@ -466,7 +466,7 @@ class AveragePrecision(BaseMetric):
 Examples:
 >>> import torch
->>> from mmcls.evaluation import AveragePrecision
+>>> from mmpretrain.evaluation import AveragePrecision
 >>> # --------- The Basic Usage for one-hot pred scores ---------
 >>> y_pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2],
 ... [0.1, 0.2, 0.2, 0.1],
@@ -479,7 +479,7 @@ class AveragePrecision(BaseMetric):
 >>> AveragePrecision.calculate(y_pred, y_true)
 tensor(70.833)
 >>> # ------------------- Use with Evalutor -------------------
->>> from mmcls.structures import ClsDataSample
+>>> from mmpretrain.structures import ClsDataSample
 >>> from mmengine.evaluator import Evaluator
 >>> data_samples = [
 ... ClsDataSample().set_pred_score(i).set_gt_score(j)

View File

@@ -3,7 +3,7 @@ from typing import Dict, Sequence
 from mmengine.evaluator import BaseMetric
-from mmcls.registry import METRICS
+from mmpretrain.registry import METRICS
 @METRICS.register_module()
@@ -14,7 +14,7 @@ class MultiTasksMetric(BaseMetric):
 and the values is a list of the metric corresponds to this task
 Examples:
 >>> import torch
->>> from mmcls.evaluation import MultiTasksMetric
+>>> from mmpretrain.evaluation import MultiTasksMetric
 # -------------------- The Basic Usage --------------------
 >>>task_metrics = {
 'task0': [dict(type='Accuracy', topk=(1, ))],

View File

@@ -8,7 +8,7 @@ from mmengine.evaluator import BaseMetric
 from mmengine.structures import LabelData
 from mmengine.utils import is_seq_of
-from mmcls.registry import METRICS
+from mmpretrain.registry import METRICS
 from .single_label import to_tensor
@@ -33,7 +33,7 @@ class RetrievalRecall(BaseMetric):
 Use in the code:
 >>> import torch
->>> from mmcls.evaluation import RetrievalRecall
+>>> from mmpretrain.evaluation import RetrievalRecall
 >>> # -------------------- The Basic Usage --------------------
 >>> y_pred = [[0], [1], [2], [3]]
 >>> y_true = [[0, 1], [2], [1], [0, 3]]
@@ -48,7 +48,7 @@ class RetrievalRecall(BaseMetric):
 [tensor(9.3000), tensor(48.4000)]
 >>>
 >>> # ------------------- Use with Evalutor -------------------
->>> from mmcls.structures import ClsDataSample
+>>> from mmpretrain.structures import ClsDataSample
 >>> from mmengine.evaluator import Evaluator
 >>> data_samples = [
 ... ClsDataSample().set_gt_label([0, 1]).set_pred_score(

View File

@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from mmengine.evaluator import BaseMetric
-from mmcls.registry import METRICS
+from mmpretrain.registry import METRICS
 def to_tensor(value):
@@ -91,7 +91,7 @@ class Accuracy(BaseMetric):
 Examples:
 >>> import torch
->>> from mmcls.evaluation import Accuracy
+>>> from mmpretrain.evaluation import Accuracy
 >>> # -------------------- The Basic Usage --------------------
 >>> y_pred = [0, 2, 1, 3]
 >>> y_true = [0, 1, 2, 3]
@@ -104,7 +104,7 @@ class Accuracy(BaseMetric):
 [[tensor([9.9000])], [tensor([51.5000])]]
 >>>
 >>> # ------------------- Use with Evalutor -------------------
->>> from mmcls.structures import ClsDataSample
+>>> from mmpretrain.structures import ClsDataSample
 >>> from mmengine.evaluator import Evaluator
 >>> data_samples = [
 ... ClsDataSample().set_gt_label(0).set_pred_score(torch.rand(10))
@@ -343,7 +343,7 @@ class SingleLabelMetric(BaseMetric):
 Examples:
 >>> import torch
->>> from mmcls.evaluation import SingleLabelMetric
+>>> from mmpretrain.evaluation import SingleLabelMetric
 >>> # -------------------- The Basic Usage --------------------
 >>> y_pred = [0, 1, 1, 3]
 >>> y_true = [0, 2, 1, 3]
@@ -358,7 +358,7 @@ class SingleLabelMetric(BaseMetric):
 (tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
 >>>
 >>> # ------------------- Use with Evalutor -------------------
->>> from mmcls.structures import ClsDataSample
+>>> from mmpretrain.structures import ClsDataSample
 >>> from mmengine.evaluator import Evaluator
 >>> data_samples = [
 ... ClsDataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
@@ -606,7 +606,7 @@ class ConfusionMatrix(BaseMetric):
 1. The basic usage.
 >>> import torch
->>> from mmcls.evaluation import ConfusionMatrix
+>>> from mmpretrain.evaluation import ConfusionMatrix
 >>> y_pred = [0, 1, 1, 3]
 >>> y_true = [0, 2, 1, 3]
 >>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)

View File

@@ -3,7 +3,7 @@ from typing import Optional, Sequence
 from mmengine.structures import LabelData
-from mmcls.registry import METRICS
+from mmpretrain.registry import METRICS
 from .multi_label import AveragePrecision, MultiLabelMetric

View File

@@ -8,6 +8,8 @@ from .heads import * # noqa: F401,F403
 from .losses import * # noqa: F401,F403
 from .necks import * # noqa: F401,F403
 from .retrievers import * # noqa: F401,F403
+from .selfsup import * # noqa: F401,F403
+from .target_generators import * # noqa: F401,F403
 from .tta import * # noqa: F401,F403
 from .utils import * # noqa: F401,F403

View File

@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.nn as nn
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone

View File

@@ -9,7 +9,7 @@ from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
 from mmengine.model import BaseModule, ModuleList
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from ..utils import (BEiTAttention, resize_pos_embed,
 resize_relative_position_bias_table, to_2tuple)
 from .vision_transformer import TransformerEncoderLayer, VisionTransformer

View File

@@ -10,7 +10,7 @@ from mmcv.cnn.bricks.transformer import AdaptivePadding
 from mmengine.model import BaseModule
 from mmengine.model.weight_init import trunc_normal_
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone
 from .vision_transformer import TransformerEncoderLayer

View File

@@ -7,7 +7,7 @@ from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
 build_norm_layer)
 from mmengine.utils import digit_version
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone

View File

@@ -9,7 +9,7 @@ import torch.utils.checkpoint as cp
 from mmcv.cnn.bricks import DropPath
 from mmengine.model import BaseModule, ModuleList, Sequential
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from ..utils import GRN, build_norm_layer
 from .base_backbone import BaseBackbone

View File

@@ -9,7 +9,7 @@ from mmcv.cnn.bricks import DropPath
 from mmengine.model import BaseModule, Sequential
 from torch.nn.modules.batchnorm import _BatchNorm
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from ..utils import to_ntuple
 from .resnet import Bottleneck as ResNetBottleneck
 from .resnext import Bottleneck as ResNeXtBottleneck
@@ -278,8 +278,8 @@ class CSPNet(BaseModule):
 >>> from functools import partial
 >>> import torch
 >>> import torch.nn as nn
->>> from mmcls.models import CSPNet
->>> from mmcls.models.backbones.resnet import Bottleneck
+>>> from mmpretrain.models import CSPNet
+>>> from mmpretrain.models.backbones.resnet import Bottleneck
 >>>
 >>> # A simple example to build CSPNet.
 >>> arch = dict(
@@ -427,7 +427,7 @@ class CSPDarkNet(CSPNet):
 Default: None.
 Example:
->>> from mmcls.models import CSPDarkNet
+>>> from mmpretrain.models import CSPDarkNet
 >>> import torch
 >>> model = CSPDarkNet(depth=53, out_indices=(0, 1, 2, 3, 4))
 >>> model.eval()
@@ -523,7 +523,7 @@ class CSPResNet(CSPNet):
 init_cfg (dict or list[dict], optional): Initialization config dict.
 Default: None.
 Example:
->>> from mmcls.models import CSPResNet
+>>> from mmpretrain.models import CSPResNet
 >>> import torch
 >>> model = CSPResNet(depth=50, out_indices=(0, 1, 2, 3))
 >>> model.eval()
@@ -645,7 +645,7 @@ class CSPResNeXt(CSPResNet):
 init_cfg (dict or list[dict], optional): Initialization config dict.
 Default: None.
 Example:
->>> from mmcls.models import CSPResNeXt
+>>> from mmpretrain.models import CSPResNeXt
 >>> import torch
 >>> model = CSPResNeXt(depth=50, out_indices=(0, 1, 2, 3))
 >>> model.eval()

View File

@@ -12,8 +12,8 @@ from mmengine.model import BaseModule, ModuleList
 from mmengine.utils import to_2tuple
 from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
-from mmcls.models.backbones.base_backbone import BaseBackbone
-from mmcls.registry import MODELS
+from mmpretrain.models.backbones.base_backbone import BaseBackbone
+from mmpretrain.registry import MODELS
 from ..utils import ShiftWindowMSA

View File

@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from mmengine.model.weight_init import trunc_normal_
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .vision_transformer import VisionTransformer

View File

@@ -10,7 +10,7 @@ from mmengine.model import BaseModule, ModuleList, Sequential
 from mmengine.utils import deprecated_api_warning
 from torch import nn
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from ..utils import LayerScale, MultiheadAttention, resize_pos_embed, to_2tuple
 from .vision_transformer import VisionTransformer

View File

@@ -10,7 +10,7 @@ import torch.utils.checkpoint as cp
 from mmcv.cnn.bricks import build_activation_layer, build_norm_layer
 from torch.jit.annotations import List
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone

View File

@@ -8,7 +8,7 @@ import torch.nn as nn
 from mmcv.cnn.bricks import DropPath
 from mmengine.model import BaseModule, ModuleList, Sequential
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from ..utils import (ChannelMultiheadAttention, PositionEncodingFourier,
 build_norm_layer)
 from .base_backbone import BaseBackbone

View File

@@ -8,7 +8,7 @@ from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
 build_norm_layer)
 from mmengine.model import BaseModule, ModuleList, Sequential
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from ..utils import LayerScale
 from .base_backbone import BaseBackbone
 from .poolformer import Pooling
@@ -376,7 +376,7 @@ class EfficientFormer(BaseBackbone):
 Defaults to None.
 Example:
->>> from mmcls.models import EfficientFormer
+>>> from mmpretrain.models import EfficientFormer
 >>> import torch
 >>> inputs = torch.rand((1, 3, 224, 224))
 >>> # build EfficientFormer backbone for classification task

View File

@@ -9,9 +9,9 @@ import torch.utils.checkpoint as cp
 from mmcv.cnn.bricks import ConvModule, DropPath
 from mmengine.model import BaseModule, Sequential
-from mmcls.models.backbones.base_backbone import BaseBackbone
-from mmcls.models.utils import InvertedResidual, SELayer, make_divisible
-from mmcls.registry import MODELS
+from mmpretrain.models.backbones.base_backbone import BaseBackbone
+from mmpretrain.models.utils import InvertedResidual, SELayer, make_divisible
+from mmpretrain.registry import MODELS
 class EdgeResidual(BaseModule):

View File

@@ -7,10 +7,10 @@ from mmcv.cnn.bricks import ConvModule, DropPath
 from mmengine.model import Sequential
 from torch import Tensor
-from mmcls.models.backbones.base_backbone import BaseBackbone
-from mmcls.models.backbones.efficientnet import EdgeResidual as FusedMBConv
-from mmcls.models.utils import InvertedResidual as MBConv
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
+from ..utils import InvertedResidual as MBConv
+from .base_backbone import BaseBackbone
+from .efficientnet import EdgeResidual as FusedMBConv
 class EnhancedConvModule(ConvModule):

View File

@@ -16,8 +16,8 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 from mmcv.cnn.bricks import DropPath
-from mmcls.models.backbones.base_backbone import BaseBackbone
-from mmcls.registry import MODELS
+from mmpretrain.models.backbones.base_backbone import BaseBackbone
+from mmpretrain.registry import MODELS
 from ..utils import LayerScale

View File

@@ -4,7 +4,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
 from mmengine.model import BaseModule, ModuleList, Sequential
 from torch.nn.modules.batchnorm import _BatchNorm
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .resnet import BasicBlock, Bottleneck, ResLayer, get_expansion
@@ -232,7 +232,7 @@ class HRNet(BaseModule):
 Example:
 >>> import torch
->>> from mmcls.models import HRNet
+>>> from mmpretrain.models import HRNet
 >>> extra = dict(
 >>> stage1=dict(
 >>> num_modules=1,

View File

@@ -6,7 +6,7 @@ import torch.nn as nn
 from mmcv.cnn import build_conv_layer
 from mmengine.model import BaseModule
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone
@@ -389,7 +389,7 @@ class InceptionV3(BaseBackbone):
 Example:
 >>> import torch
->>> from mmcls.models import build_backbone
+>>> from mmpretrain.models import build_backbone
 >>>
 >>> inputs = torch.rand(2, 3, 299, 299)
 >>> cfg = dict(type='InceptionV3', num_classes=100)

View File

@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.nn as nn
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone

View File

@@ -7,8 +7,8 @@ from mmcv.cnn import build_activation_layer, fuse_conv_bn
 from mmcv.cnn.bricks import DropPath
 from mmengine.model import BaseModule, ModuleList, Sequential
-from mmcls.models.backbones.base_backbone import BaseBackbone
-from mmcls.registry import MODELS
+from mmpretrain.models.backbones.base_backbone import BaseBackbone
+from mmpretrain.registry import MODELS
 from ..utils import build_norm_layer

View File

@@ -9,11 +9,10 @@ from mmengine.model import BaseModule
 from torch import nn
 from torch.utils.checkpoint import checkpoint
-from mmcls.models.backbones.base_backbone import BaseBackbone
-from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
-from mmcls.models.utils.attention import WindowMSA
-from mmcls.models.utils.helpers import to_2tuple
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
+from ..utils import WindowMSA, to_2tuple
+from .base_backbone import BaseBackbone
+from .vision_transformer import TransformerEncoderLayer
 class MixMIMWindowAttention(WindowMSA):

View File

@@ -6,7 +6,7 @@ from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
 from mmengine.model import BaseModule, ModuleList
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from ..utils import to_2tuple
 from .base_backbone import BaseBackbone

View File

@@ -5,8 +5,8 @@ from mmcv.cnn import ConvModule
 from mmengine.model import BaseModule
 from torch.nn.modules.batchnorm import _BatchNorm
-from mmcls.models.utils import make_divisible
-from mmcls.registry import MODELS
+from mmpretrain.models.utils import make_divisible
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone

View File

@@ -2,7 +2,7 @@
 from mmcv.cnn import ConvModule
 from torch.nn.modules.batchnorm import _BatchNorm
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from ..utils import InvertedResidual
 from .base_backbone import BaseBackbone

View File

@@ -9,7 +9,7 @@ from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
 from mmengine.model import BaseModule, ModuleList, Sequential
 from torch.nn.modules.batchnorm import _BatchNorm
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from ..utils.se_layer import SELayer
 from .base_backbone import BaseBackbone
@@ -302,7 +302,7 @@ class MobileOne(BaseBackbone):
 init_cfg (dict or list[dict], optional): Initialization config dict.
 Example:
->>> from mmcls.models import MobileOne
+>>> from mmpretrain.models import MobileOne
 >>> import torch
 >>> x = torch.rand(1, 3, 224, 224)
 >>> model = MobileOne("s0", out_indices=(0, 1, 2, 3))

View File

@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule, build_norm_layer
 from torch import nn
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone
 from .mobilenet_v2 import InvertedResidual
 from .vision_transformer import TransformerEncoderLayer

View File

@@ -489,7 +489,7 @@ class MViT(BaseBackbone):
 Examples:
 >>> import torch
->>> from mmcls.models import build_backbone
+>>> from mmpretrain.models import build_backbone
 >>>
 >>> cfg = dict(type='MViT', arch='tiny', out_scales=[0, 1, 2, 3])
 >>> model = build_backbone(cfg)
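The MViT doctest above is cut off by the diff view; a self-contained sketch of the same config-driven pattern under the new package name follows (ResNet depth 18 and the 224x224 input are stand-ins borrowed from other doctests in this diff, not from this hunk):

    import torch
    from mmpretrain.models import build_backbone

    cfg = dict(type='ResNet', depth=18)  # any registered backbone config works here
    backbone = build_backbone(cfg)
    backbone.eval()
    with torch.no_grad():
        outs = backbone(torch.rand(1, 3, 224, 224))
    print([o.shape for o in outs])  # one feature map per returned stage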

View File

@@ -6,7 +6,7 @@ import torch.nn as nn
 from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
 from mmengine.model import BaseModule
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone

View File

@@ -3,7 +3,7 @@ import numpy as np
 import torch.nn as nn
 from mmcv.cnn import build_conv_layer, build_norm_layer
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .resnet import ResNet
 from .resnext import Bottleneck
@@ -43,7 +43,7 @@ class RegNet(ResNet):
 in resblocks to let them behave as identity. Default: True.
 Example:
->>> from mmcls.models import RegNet
+>>> from mmpretrain.models import RegNet
 >>> import torch
 >>> self = RegNet(
 arch=dict(

View File

@@ -7,7 +7,7 @@ from mmcv.cnn.bricks import DropPath
 from mmengine.model import BaseModule
 from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone

View File

@@ -8,8 +8,8 @@ from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
 from mmcv.cnn.bricks.transformer import PatchEmbed as _PatchEmbed
 from mmengine.model import BaseModule, ModuleList, Sequential
-from mmcls.models.utils import SELayer, to_2tuple
-from mmcls.registry import MODELS
+from mmpretrain.models.utils import SELayer, to_2tuple
+from mmpretrain.registry import MODELS
 def fuse_bn(conv_or_fc, bn):
@@ -88,7 +88,7 @@ class PatchEmbed(_PatchEmbed):
 class GlobalPerceptron(SELayer):
-"""GlobalPerceptron implemented by using ``mmcls.modes.SELayer``.
+"""GlobalPerceptron implemented by using ``mmpretrain.modes.SELayer``.
 Args:
 input_channels (int): The number of input (and output) channels

View File

@@ -8,7 +8,7 @@ from mmengine.model import BaseModule, Sequential
 from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
 from torch import nn
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from ..utils.se_layer import SELayer
 from .base_backbone import BaseBackbone

View File

@@ -7,7 +7,7 @@ import torch.utils.checkpoint as cp
 from mmcv.cnn import build_conv_layer, build_norm_layer
 from mmengine.model import ModuleList, Sequential
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .resnet import Bottleneck as _Bottleneck
 from .resnet import ResNet
@@ -258,7 +258,7 @@ class Res2Net(ResNet):
 Defaults to None.
 Example:
->>> from mmcls.models import Res2Net
+>>> from mmpretrain.models import Res2Net
 >>> import torch
 >>> model = Res2Net(depth=50,
 ...                 scales=4,

View File

@@ -5,7 +5,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as cp
 from mmcv.cnn import build_conv_layer, build_norm_layer
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .resnet import Bottleneck as _Bottleneck
 from .resnet import ResLayer, ResNetV1d

View File

@@ -9,7 +9,7 @@ from mmengine.model import BaseModule
 from mmengine.model.weight_init import constant_init
 from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .base_backbone import BaseBackbone
 eps = 1.0e-5
@@ -438,7 +438,7 @@ class ResNet(BaseBackbone):
 in resblocks to let them behave as identity. Default: True.
 Example:
->>> from mmcls.models import ResNet
+>>> from mmpretrain.models import ResNet
 >>> import torch
 >>> self = ResNet(depth=18)
 >>> self.eval()

View File

@@ -2,7 +2,7 @@
 import torch.nn as nn
 from mmcv.cnn import build_conv_layer, build_norm_layer
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .resnet import ResNet

View File

@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.cnn import build_conv_layer, build_norm_layer
-from mmcls.registry import MODELS
+from mmpretrain.registry import MODELS
 from .resnet import Bottleneck as _Bottleneck
 from .resnet import ResLayer, ResNet

Some files were not shown because too many files have changed in this diff.