Rename the package name to `mmpretrain`.

pull/1371/head
mzr1996 2023-02-17 11:31:08 +08:00
parent 8352951f3d
commit 0979e78573
324 changed files with 721 additions and 691 deletions

View File

@ -17,10 +17,10 @@ from modelindex.load_model_index import load
from rich.console import Console
from rich.table import Table
from mmcls.apis import init_model
from mmcls.datasets import CIFAR10, CIFAR100, ImageNet
from mmcls.utils import register_all_modules
from mmcls.visualization import ClsVisualizer
from mmpretrain.apis import init_model
from mmpretrain.datasets import CIFAR10, CIFAR100, ImageNet
from mmpretrain.utils import register_all_modules
from mmpretrain.visualization import ClsVisualizer
console = Console()
MMCLS_ROOT = Path(__file__).absolute().parents[2]

View File

@ -18,9 +18,9 @@ from modelindex.load_model_index import load
from rich.console import Console
from rich.table import Table
from mmcls.datasets.builder import build_dataloader
from mmcls.datasets.pipelines import Compose
from mmcls.models.builder import build_classifier
from mmpretrain.datasets.builder import build_dataloader
from mmpretrain.datasets.pipelines import Compose
from mmpretrain.models.builder import build_classifier
console = Console()
MMCLS_ROOT = Path(__file__).absolute().parents[2]

View File

@ -47,7 +47,7 @@ def ckpt_to_state_dict(checkpoint, key=None):
if key is not None:
state_dict = checkpoint[key]
elif 'state_dict' in checkpoint:
# try mmcls style
# try mmpretrain style
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
@ -149,7 +149,7 @@ def main():
if args.path.suffix in ['.json', '.py', '.yml']:
from mmengine.runner import get_state_dict
from mmcls.apis import init_model
from mmpretrain.apis import init_model
model = init_model(args.path, device='cpu')
state_dict = get_state_dict(model)
else:

View File

@ -55,7 +55,7 @@ def state_dict_from_cfg_or_ckpt(path, state_key=None):
if path.suffix in ['.json', '.py', '.yml']:
from mmengine.runner import get_state_dict
from mmcls.apis import init_model
from mmpretrain.apis import init_model
model = init_model(path, device='cpu')
model.init_weights()
return get_state_dict(model)

View File

@ -84,8 +84,8 @@ def get_flops(config_path):
from mmengine.dataset import Compose
from mmengine.registry import DefaultScope
import mmcls.datasets # noqa: F401
from mmcls.apis import init_model
import mmpretrain.datasets # noqa: F401
from mmpretrain.apis import init_model
cfg = Config.fromfile(config_path)
@ -98,7 +98,7 @@ def get_flops(config_path):
# The image shape of CIFAR is (32, 32, 3)
test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
with DefaultScope.overwrite_default_scope('mmcls'):
with DefaultScope.overwrite_default_scope('mmpretrain'):
data = Compose(test_dataset.pipeline)({
'img':
np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)

2
.gitignore vendored
View File

@ -127,7 +127,7 @@ venv.bak/
/work_dirs
/projects/*/work_dirs
/projects/*/data
/mmcls/.mim
/mmpretrain/.mim
.DS_Store
# Pytorch

View File

@ -1,4 +1,4 @@
include requirements/*.txt
include mmcls/.mim/model-index.yml
recursive-include mmcls/.mim/configs *.py *.yml
recursive-include mmcls/.mim/tools *.py *.sh
include mmpretrain/.mim/model-index.yml
recursive-include mmpretrain/.mim/configs *.py *.yml
recursive-include mmpretrain/.mim/tools *.py *.sh

View File

@ -1,5 +1,5 @@
# defaults to use registries in mmcls
default_scope = 'mmcls'
# defaults to use registries in mmpretrain
default_scope = 'mmpretrain'
# configure default hooks
default_hooks = dict(

View File

@ -1,4 +1,4 @@
from mmcls.models import build_classifier
from mmpretrain.models import build_classifier
model = dict(
type='ImageClassifier',

View File

@ -2,11 +2,11 @@ _base_ = ['../_base_/datasets/voc_bs16.py', '../_base_/default_runtime.py']
# Pre-trained Checkpoint Path
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth' # noqa
# If you want to use the pre-trained weight of ResNet101-CutMix from
# the originary repo(https://github.com/Kevinz-code/CSRA). Script of
# 'tools/convert_models/torchvision_to_mmcls.py' can help you convert weight
# into mmcls format. The mAP result would hit 95.5 by using the weight.
# checkpoint = 'PATH/TO/PRE-TRAINED_WEIGHT'
# If you want to use the pre-trained weight of ResNet101-CutMix from the
# originary repo(https://github.com/Kevinz-code/CSRA). Script of
# 'tools/model_converters/torchvision_to_mmpretrain.py' can help you convert
# weight into mmpretrain format. The mAP result would hit 95.5 by using the
# weight. checkpoint = 'PATH/TO/PRE-TRAINED_WEIGHT'
# model settings
model = dict(

View File

@ -51,7 +51,7 @@ val_cfg = dict()
test_cfg = dict()
# runtime settings
default_scope = 'mmcls'
default_scope = 'mmpretrain'
default_hooks = dict(
# record the time of every iteration.

View File

@ -4,7 +4,7 @@ from argparse import ArgumentParser
from mmengine.fileio import dump
from rich import print_json
from mmcls.apis import ImageClassificationInferencer
from mmpretrain.apis import ImageClassificationInferencer
def main():
@ -30,7 +30,7 @@ def main():
raise ValueError(
f'Unavailable model "{args.model}", you can specify find a model '
'name or a config file or find a model name from '
'https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html#all-checkpoints' # noqa: E501
'https://mmpretrain.readthedocs.io/en/1.x/modelzoo_statistics.html#all-checkpoints' # noqa: E501
)
result = inferencer(args.img, show=args.show, show_dir=args.show_dir)[0]
# show the results

View File

@ -22,12 +22,12 @@ sys.path.insert(0, os.path.abspath('../../'))
# -- Project information -----------------------------------------------------
project = 'MMClassification'
project = 'MMPretrain'
copyright = '2020, OpenMMLab'
author = 'MMClassification Authors'
author = 'MMPretrain Authors'
# The full version, including alpha/beta/rc tags
version_file = '../../mmcls/version.py'
version_file = '../../mmpretrain/version.py'
def get_version():
@ -92,25 +92,25 @@ html_theme_options = {
'menu': [
{
'name': 'GitHub',
'url': 'https://github.com/open-mmlab/mmclassification'
'url': 'https://github.com/open-mmlab/mmpretrain'
},
{
'name': 'Colab Tutorials',
'children': [
{'name': 'Train and inference with shell commands',
'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_tools.ipynb'},
'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_tools.ipynb'},
{'name': 'Train and inference with Python APIs',
'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_python.ipynb'},
'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_python.ipynb'},
]
},
{
'name': 'Version',
'children': [
{'name': 'MMClassification 0.x',
'url': 'https://mmclassification.readthedocs.io/en/latest/',
{'name': 'MMPretrain 0.x',
'url': 'https://mmpretrain.readthedocs.io/en/latest/',
'description': 'master branch'},
{'name': 'MMClassification 1.x',
'url': 'https://mmclassification.readthedocs.io/en/dev-1.x/',
{'name': 'MMPretrain 1.x',
'url': 'https://mmpretrain.readthedocs.io/en/dev-1.x/',
'description': '1.x branch'},
],
}
@ -138,7 +138,7 @@ html_js_files = [
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'mmclsdoc'
htmlhelp_basename = 'mmpretraindoc'
# -- Options for LaTeX output ------------------------------------------------
@ -160,16 +160,14 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(root_doc, 'mmcls.tex', 'MMClassification Documentation', author,
'manual'),
(root_doc, 'mmpretrain.tex', 'MMPretrain Documentation', author, 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
]
man_pages = [(root_doc, 'mmpretrain', 'MMPretrain Documentation', [author], 1)]
# -- Options for Texinfo output ----------------------------------------------
@ -177,7 +175,7 @@ man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(root_doc, 'mmcls', 'MMClassification Documentation', author, 'mmcls',
(root_doc, 'mmpretrain', 'MMPretrain Documentation', author, 'mmpretrain',
'OpenMMLab image classification toolbox and benchmark.', 'Miscellaneous'),
]

View File

@ -9,7 +9,7 @@ from tabulate import tabulate
MMCLS_ROOT = Path(__file__).absolute().parents[2]
PAPERS_ROOT = Path('papers') # Path to save generated paper pages.
GITHUB_PREFIX = 'https://github.com/open-mmlab/mmclassification/blob/1.x/'
GITHUB_PREFIX = 'https://github.com/open-mmlab/mmpretrain/blob/1.x/'
MODELZOO_TEMPLATE = """
# Model Zoo Summary

View File

@ -22,12 +22,12 @@ sys.path.insert(0, os.path.abspath('../..'))
# -- Project information -----------------------------------------------------
project = 'MMClassification'
project = 'MMPretrain'
copyright = '2020, OpenMMLab'
author = 'MMClassification Authors'
author = 'MMPretrain Authors'
# The full version, including alpha/beta/rc tags
version_file = '../../mmcls/version.py'
version_file = '../../mmpretrain/version.py'
def get_version():
@ -92,25 +92,25 @@ html_theme_options = {
'menu': [
{
'name': 'GitHub',
'url': 'https://github.com/open-mmlab/mmclassification'
'url': 'https://github.com/open-mmlab/mmpretrain'
},
{
'name': 'Colab 教程',
'children': [
{'name': '用命令行工具训练和推理',
'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_tools.ipynb'},
'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_tools.ipynb'},
{'name': '用 Python API 训练和推理',
'url': 'https://colab.research.google.com/github/mzr1996/mmclassification-tutorial/blob/master/1.x/MMClassification_python.ipynb'},
'url': 'https://colab.research.google.com/github/mzr1996/mmpretrain-tutorial/blob/master/1.x/MMPretrain_python.ipynb'},
]
},
{
'name': 'Version',
'children': [
{'name': 'MMClassification 0.x',
'url': 'https://mmclassification.readthedocs.io/zh_CN/latest/',
{'name': 'MMPretrain 0.x',
'url': 'https://mmpretrain.readthedocs.io/zh_CN/latest/',
'description': 'master branch'},
{'name': 'MMClassification 1.x',
'url': 'https://mmclassification.readthedocs.io/zh_CN/dev-1.x/',
{'name': 'MMPretrain 1.x',
'url': 'https://mmpretrain.readthedocs.io/zh_CN/dev-1.x/',
'description': '1.x branch'},
],
}
@ -138,7 +138,7 @@ html_js_files = [
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'mmclsdoc'
htmlhelp_basename = 'mmpretraindoc'
# -- Options for LaTeX output ------------------------------------------------
@ -164,16 +164,14 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(root_doc, 'mmcls.tex', 'MMClassification Documentation', author,
'manual'),
(root_doc, 'mmpretrain.tex', 'MMPretrain Documentation', author, 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
]
man_pages = [(root_doc, 'mmpretrain', 'MMPretrain Documentation', [author], 1)]
# -- Options for Texinfo output ----------------------------------------------
@ -181,7 +179,7 @@ man_pages = [(root_doc, 'mmcls', 'MMClassification Documentation', [author], 1)
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(root_doc, 'mmcls', 'MMClassification Documentation', author, 'mmcls',
(root_doc, 'mmpretrain', 'MMPretrain Documentation', author, 'mmpretrain',
'OpenMMLab image classification toolbox and benchmark.', 'Miscellaneous'),
]

View File

@ -9,7 +9,7 @@ from tabulate import tabulate
MMCLS_ROOT = Path(__file__).absolute().parents[2]
PAPERS_ROOT = Path('papers') # Path to save generated paper pages.
GITHUB_PREFIX = 'https://github.com/open-mmlab/mmclassification/blob/1.x/'
GITHUB_PREFIX = 'https://github.com/open-mmlab/mmpretrain/blob/1.x/'
MODELZOO_TEMPLATE = """
# 模型库统计

View File

@ -1,39 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import warnings
from mmengine import DefaultScope
def register_all_modules(init_default_scope: bool = True) -> None:
"""Register all modules in mmcls into the registries.
Args:
init_default_scope (bool): Whether initialize the mmcls default scope.
If True, the global default scope will be set to `mmcls`, and all
registries will build modules from mmcls's registry node. To
understand more about the registry, please refer to
https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
Defaults to True.
""" # noqa
import mmcls.datasets # noqa: F401,F403
import mmcls.engine # noqa: F401,F403
import mmcls.evaluation # noqa: F401,F403
import mmcls.models # noqa: F401,F403
import mmcls.structures # noqa: F401,F403
import mmcls.visualization # noqa: F401,F403
if not init_default_scope:
return
current_scope = DefaultScope.get_current_instance()
if current_scope is None:
DefaultScope.get_instance('mmcls', scope_name='mmcls')
elif current_scope.scope_name != 'mmcls':
warnings.warn(f'The current default scope "{current_scope.scope_name}"'
' is not "mmcls", `register_all_modules` will force the '
'current default scope to be "mmcls". If this is not '
'expected, please set `init_default_scope=False`.')
# avoid name conflict
new_instance_name = f'mmcls-{datetime.datetime.now()}'
DefaultScope.get_instance(new_instance_name, scope_name='mmcls')

View File

@ -11,8 +11,8 @@ from mmengine.infer import BaseInferencer
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from mmcls.registry import TRANSFORMS
from mmcls.structures import ClsDataSample
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import ClsDataSample
from .model import get_model, init_model, list_models
ModelType = Union[BaseModel, str, Config]
@ -63,7 +63,7 @@ class ImageClassificationInferencer(BaseInferencer):
Example:
1. Use a pre-trained model in MMClassification to inference an image.
>>> from mmcls import ImageClassificationInferencer
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
>>> inferencer('demo/demo.JPEG')
[{'pred_score': array([...]),
@ -74,7 +74,7 @@ class ImageClassificationInferencer(BaseInferencer):
2. Use a config file and checkpoint to inference multiple images on GPU,
and save the visualization results in a folder.
>>> from mmcls import ImageClassificationInferencer
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer(
model='configs/resnet/resnet50_8xb32_in1k.py',
weights='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
@ -107,7 +107,7 @@ class ImageClassificationInferencer(BaseInferencer):
else:
raise TypeError(
'The `model` can be a name of model and you can use '
'`mmcls.list_models` to get an available name. It can '
'`mmpretrain.list_models` to get an available name. It can '
'also be a Config object or a path to the config file.')
model.eval()
@ -185,7 +185,7 @@ class ImageClassificationInferencer(BaseInferencer):
return None
if self.visualizer is None:
from mmcls.visualization import ClsVisualizer
from mmpretrain.visualization import ClsVisualizer
self.visualizer = ClsVisualizer()
if self.classes is not None:
self.visualizer._dataset_meta = dict(classes=self.classes)

View File

@ -15,7 +15,7 @@ from modelindex.models.Model import Model
class ModelHub:
"""A hub to host the meta information of all pre-defined models."""
_models_dict = {}
__mmcls_registered = False
__mmpretrain_registered = False
@classmethod
def register_model_index(cls,
@ -52,7 +52,7 @@ class ModelHub:
Returns:
modelindex.models.Model: The metainfo of the specified model.
"""
cls._register_mmcls_models()
cls._register_mmpretrain_models()
# lazy load config
metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
if metainfo is None:
@ -75,15 +75,15 @@ class ModelHub:
return config_path
@classmethod
def _register_mmcls_models(cls):
# register models in mmcls
if not cls.__mmcls_registered:
def _register_mmpretrain_models(cls):
# register models in mmpretrain
if not cls.__mmpretrain_registered:
from mmengine.utils import get_installed_path
mmcls_root = Path(get_installed_path('mmcls'))
model_index_path = mmcls_root / '.mim' / 'model-index.yml'
mmpretrain_root = Path(get_installed_path('mmpretrain'))
model_index_path = mmpretrain_root / '.mim' / 'model-index.yml'
ModelHub.register_model_index(
model_index_path, config_prefix=mmcls_root / '.mim')
cls.__mmcls_registered = True
model_index_path, config_prefix=mmpretrain_root / '.mim')
cls.__mmpretrain_registered = True
@classmethod
def has(cls, model_name):
@ -118,7 +118,7 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
config.model.setdefault('data_preprocessor',
config.get('data_preprocessor', None))
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
model = MODELS.build(config.model)
if checkpoint is not None:
@ -130,13 +130,13 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
# Don't set CLASSES if the model is headless.
pass
elif 'dataset_meta' in checkpoint.get('meta', {}):
# mmcls 1.x
# mmpretrain 1.x
model.CLASSES = checkpoint['meta']['dataset_meta'].get('classes')
elif 'CLASSES' in checkpoint.get('meta', {}):
# mmcls < 1.x
# mmpretrain < 1.x or mmselfsup < 1.x
model.CLASSES = checkpoint['meta']['CLASSES']
else:
from mmcls.datasets.categories import IMAGENET_CATEGORIES
from mmpretrain.datasets.categories import IMAGENET_CATEGORIES
warnings.simplefilter('once')
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use imagenet by default.')
@ -165,7 +165,7 @@ def get_model(model_name, pretrained=False, device=None, **kwargs):
Get a ResNet-50 model and extract images feature:
>>> import torch
>>> from mmcls import get_model
>>> from mmpretrain import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
@ -178,7 +178,7 @@ def get_model(model_name, pretrained=False, device=None, **kwargs):
Get Swin-Transformer model with pre-trained weights and inference:
>>> from mmcls import get_model, inference_model
>>> from mmpretrain import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
@ -201,7 +201,7 @@ def get_model(model_name, pretrained=False, device=None, **kwargs):
def list_models(pattern=None) -> List[str]:
"""List all models available in MMClassification.
"""List all models available in MMPretrain.
Args:
pattern (str | None): A wildcard pattern to match model names.
@ -212,12 +212,12 @@ def list_models(pattern=None) -> List[str]:
Examples:
List all models:
>>> from mmcls import list_models
>>> from mmpretrain import list_models
>>> print(list_models())
List ResNet-50 models on ImageNet-1k dataset:
>>> from mmcls import list_models
>>> from mmpretrain import list_models
>>> print(list_models('resnet*in1k'))
['resnet50_8xb32_in1k',
'resnet50_8xb32-fp16_in1k',
@ -225,7 +225,7 @@ def list_models(pattern=None) -> List[str]:
'resnet50_8xb256-rsb-a2-300e_in1k',
'resnet50_8xb256-rsb-a3-100e_in1k']
"""
ModelHub._register_mmcls_models()
ModelHub._register_mmpretrain_models()
if pattern is None:
return sorted(list(ModelHub._models_dict.keys()))
# Always match keys with any postfix.

View File

@ -7,7 +7,7 @@ import mmengine
import numpy as np
from mmengine.dataset import BaseDataset as _BaseDataset
from mmcls.registry import DATASETS, TRANSFORMS
from mmpretrain.registry import DATASETS, TRANSFORMS
def expanduser(path):

View File

@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcls.registry import DATASETS
from mmpretrain.registry import DATASETS
def build_dataset(cfg):
"""Build dataset.
Examples:
>>> from mmcls.datasets import build_dataset
>>> from mmpretrain.datasets import build_dataset
>>> mnist_train = build_dataset(
... dict(type='MNIST', data_prefix='data/mnist/', test_mode=False))
>>> print(mnist_train)

View File

@ -7,7 +7,7 @@ import numpy as np
from mmengine.fileio import (LocalBackend, exists, get, get_file_backend,
join_path)
from mmcls.registry import DATASETS
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import CIFAR10_CATEGORIES, CIFAR100_CATEGORIES
from .utils import check_md5, download_and_extract_archive

View File

@ -3,7 +3,7 @@ from typing import List
from mmengine import get_file_backend, list_from_file
from mmcls.registry import DATASETS
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import CUB_CATEGORIES
@ -51,7 +51,7 @@ class CUB(BaseDataset):
Examples:
>>> from mmcls.datasets import CUB
>>> from mmpretrain.datasets import CUB
>>> cub_train_cfg = dict(data_root='data/CUB_200_2011', test_mode=True)
>>> cub_train = CUB(**cub_train_cfg)
>>> cub_train

View File

@ -5,7 +5,7 @@ from mmengine.fileio import (BaseStorageBackend, get_file_backend,
list_from_file)
from mmengine.logging import MMLogger
from mmcls.registry import DATASETS
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset

View File

@ -4,7 +4,7 @@ import copy
import numpy as np
from mmengine.dataset import BaseDataset, force_full_init
from mmcls.registry import DATASETS
from mmpretrain.registry import DATASETS
@DATASETS.register_module()

View File

@ -3,7 +3,7 @@ from typing import Optional, Union
from mmengine.logging import MMLogger
from mmcls.registry import DATASETS
from mmpretrain.registry import DATASETS
from .categories import IMAGENET_CATEGORIES
from .custom import CustomDataset

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import get_file_backend, list_from_file
from mmcls.datasets.base_dataset import BaseDataset
from mmcls.registry import DATASETS
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
@ -35,7 +35,7 @@ class InShop(BaseDataset):
**kwargs: Other keyword arguments in :class:`BaseDataset`.
Examples:
>>> from mmcls.datasets import InShop
>>> from mmpretrain.datasets import InShop
>>>
>>> # build train InShop dataset
>>> inshop_train_cfg = dict(data_root='data/inshop', split='train')

View File

@ -8,7 +8,7 @@ import numpy as np
import torch
from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path
from mmcls.registry import DATASETS
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
from .utils import (download_and_extract_archive, open_maybe_compressed_file,

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmcls.registry import DATASETS
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset

View File

@ -60,7 +60,7 @@ class MultiTaskDataset:
Assume we put our dataset in the ``data/mydataset`` folder in the
repository and organize it as the below format: ::
mmclassification/
mmpretrain/
data
mydataset
annotation
@ -81,7 +81,7 @@ class MultiTaskDataset:
.. code:: python
>>> from mmcls.datasets import build_dataset
>>> from mmpretrain.datasets import build_dataset
>>> train_cfg = dict(
... type="MultiTaskDataset",
... ann_file="annotation/train.json",
@ -94,7 +94,7 @@ class MultiTaskDataset:
Or we can put all files in the same folder: ::
mmclassification/
mmpretrain/
data
mydataset
train.json
@ -109,7 +109,7 @@ class MultiTaskDataset:
.. code:: python
>>> from mmcls.datasets import build_dataset
>>> from mmpretrain.datasets import build_dataset
>>> train_cfg = dict(
... type="MultiTaskDataset",
... ann_file="train.json",
@ -133,8 +133,8 @@ class MultiTaskDataset:
``data_root`` for the ``"img_path"`` field in the annotation file.
Defaults to None.
pipeline (Sequence[dict]): A list of dict, where each element
represents a operation defined in :mod:`mmcls.datasets.pipelines`.
Defaults to an empty tuple.
represents a operation defined in
:mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple.
test_mode (bool): in train mode or test mode. Defaults to False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.

View File

@ -5,7 +5,7 @@ import torch
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
from torch.utils.data import Sampler
from mmcls.registry import DATA_SAMPLERS
from mmpretrain.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()

View File

@ -11,7 +11,7 @@ from mmcv.transforms import BaseTransform, Compose, RandomChoice
from mmcv.transforms.utils import cache_randomness
from mmengine.utils import is_list_of, is_seq_of
from mmcls.registry import TRANSFORMS
from mmpretrain.registry import TRANSFORMS
def merge_hparams(policy: dict, hparams: dict) -> dict:
@ -125,7 +125,7 @@ class RandAugment(BaseTransform):
time, and magnitude_level of every policy is 6 (total is 10 by default)
>>> import numpy as np
>>> from mmcls.datasets import RandAugment
>>> from mmpretrain.datasets import RandAugment
>>> transform = RandAugment(
... policies='timm_increasing',
... num_policies=2,

View File

@ -9,8 +9,8 @@ from mmcv.transforms import BaseTransform
from mmengine.utils import is_str
from PIL import Image
from mmcls.registry import TRANSFORMS
from mmcls.structures import ClsDataSample, MultiTaskDataSample
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
def to_tensor(data):
@ -53,8 +53,8 @@ class PackClsInputs(BaseTransform):
**Added Keys:**
- inputs (:obj:`torch.Tensor`): The forward data of models.
- data_samples (:obj:`~mmcls.structures.ClsDataSample`): The annotation
info of the sample.
- data_samples (:obj:`~mmpretrain.structures.ClsDataSample`): The
annotation info of the sample.
Args:
meta_keys (Sequence[str]): The meta keys to be saved in the

View File

@ -11,7 +11,7 @@ import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmcls.registry import TRANSFORMS
from mmpretrain.registry import TRANSFORMS
try:
import albumentations
@ -1008,7 +1008,7 @@ class Lighting(BaseTransform):
return repr_str
# 'Albu' is used in previous versions of mmcls, here is for compatibility
# 'Albu' is used in previous versions of mmpretrain, here is for compatibility
# users can use both 'Albumentations' and 'Albu'.
@TRANSFORMS.register_module(['Albumentations', 'Albu'])
class Albumentations(BaseTransform):
@ -1055,13 +1055,13 @@ class Albumentations(BaseTransform):
Args:
transforms (List[Dict]): List of albumentations transform configs.
keymap (Optional[Dict]): Mapping of mmcls to albumentations fields,
in format {'input key':'albumentation-style key'}. Defaults to
None.
keymap (Optional[Dict]): Mapping of mmpretrain to albumentations
fields, in format {'input key':'albumentation-style key'}.
Defaults to None.
Example:
>>> import mmcv
>>> from mmcls.datasets import Albumentations
>>> from mmpretrain.datasets import Albumentations
>>> transforms = [
... dict(
... type='ShiftScaleRotate',

View File

@ -4,7 +4,7 @@ from typing import List, Optional, Union
from mmengine import get_file_backend, list_from_file
from mmcls.registry import DATASETS
from mmpretrain.registry import DATASETS
from .base_dataset import expanduser
from .categories import VOC2007_CATEGORIES
from .multi_label import MultiLabelDataset

View File

@ -2,7 +2,7 @@
from mmengine.hooks import Hook
from mmengine.utils import is_seq_of
from mmcls.registry import HOOKS
from mmpretrain.registry import HOOKS
@HOOKS.register_module()

View File

@ -8,7 +8,7 @@ from mmengine.hooks import EMAHook as BaseEMAHook
from mmengine.logging import MMLogger
from mmengine.runner import Runner
from mmcls.registry import HOOKS
from mmpretrain.registry import HOOKS
@HOOKS.register_module()

View File

@ -3,8 +3,8 @@ import numpy as np
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmcls.models.heads import ArcFaceClsHead
from mmcls.registry import HOOKS
from mmpretrain.models.heads import ArcFaceClsHead
from mmpretrain.registry import HOOKS
@HOOKS.register_module()

View File

@ -20,7 +20,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from torch.utils.data import DataLoader
from mmcls.registry import HOOKS
from mmpretrain.registry import HOOKS
DATA_BATCH = Optional[Sequence[dict]]

View File

@ -4,8 +4,8 @@ import warnings
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmcls.models import BaseRetriever
from mmcls.registry import HOOKS
from mmpretrain.models import BaseRetriever
from mmpretrain.registry import HOOKS
@HOOKS.register_module()
@ -27,5 +27,6 @@ class PrepareProtoBeforeValLoopHook(Hook):
model.prepare_prototype()
else:
warnings.warn(
'Only the `mmcls.models.retrievers.BaseRetriever` can execute '
f'`PrepareRetrieverPrototypeHook`, but got `{type(model)}`')
'Only the `mmpretrain.models.retrievers.BaseRetriever` '
'can execute `PrepareRetrieverPrototypeHook`, but got '
f'`{type(model)}`')

View File

@ -6,8 +6,8 @@ from mmcv.transforms import Compose
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmcls.models.utils import RandomBatchAugment
from mmcls.registry import HOOKS, MODEL_WRAPPERS, MODELS
from mmpretrain.models.utils import RandomBatchAugment
from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS
@HOOKS.register_module()
@ -25,9 +25,9 @@ class SwitchRecipeHook(Hook):
train dataset. If not specified, keep the original settings.
- ``batch_augments`` (dict | None, optional): The new batch
augmentations of during training. See :mod:`Batch Augmentations
<mmcls.models.utils.batch_augments>` for more details. If None,
disable batch augmentations. If not specified, keep the original
settings.
<mmpretrain.models.utils.batch_augments>` for more details.
If None, disable batch augmentations. If not specified, keep the
original settings.
- ``loss`` (dict, optional): The new loss module config. If not
specified, keep the original settings.

View File

@ -8,8 +8,8 @@ from mmengine.hooks import Hook
from mmengine.runner import EpochBasedTrainLoop, Runner
from mmengine.visualization import Visualizer
from mmcls.registry import HOOKS
from mmcls.structures import ClsDataSample
from mmpretrain.registry import HOOKS
from mmpretrain.structures import ClsDataSample
@HOOKS.register_module()
@ -30,7 +30,7 @@ class VisualizationHook(Hook):
in the testing process. If None, handle with the backends of the
visualizer. Defaults to None.
**kwargs: other keyword arguments of
:meth:`mmcls.visualization.ClsVisualizer.add_datasample`.
:meth:`mmpretrain.visualization.ClsVisualizer.add_datasample`.
"""
def __init__(self,

View File

@ -19,7 +19,7 @@ import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from mmcls.registry import OPTIMIZERS
from mmpretrain.registry import OPTIMIZERS
@OPTIMIZERS.register_module()

View File

@ -61,7 +61,7 @@ import math
import torch
from torch.optim import Optimizer
from mmcls.registry import OPTIMIZERS
from mmpretrain.registry import OPTIMIZERS
@OPTIMIZERS.register_module()

View File

@ -1,2 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .functional import * # noqa: F401,F403
from .metrics import * # noqa: F401,F403

View File

@ -7,7 +7,7 @@ from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from mmengine.structures import LabelData
from mmcls.registry import METRICS
from mmpretrain.registry import METRICS
from .single_label import _precision_recall_f1_support, to_tensor
@ -77,7 +77,7 @@ class MultiLabelMetric(BaseMetric):
Examples:
>>> import torch
>>> from mmcls.evaluation import MultiLabelMetric
>>> from mmpretrain.evaluation import MultiLabelMetric
>>> # ------ The Basic Usage for category indices labels -------
>>> y_pred = [[0], [1], [0, 1], [3]]
>>> y_true = [[0, 3], [0, 2], [1], [3]]
@ -114,7 +114,7 @@ class MultiLabelMetric(BaseMetric):
(tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8))
>>>
>>> # ------------------- Use with Evalutor -------------------
>>> from mmcls.structures import ClsDataSample
>>> from mmpretrain.structures import ClsDataSample
>>> from mmengine.evaluator import Evaluator
>>> data_sampels = [
... ClsDataSample().set_pred_score(pred).set_gt_score(gt)
@ -466,7 +466,7 @@ class AveragePrecision(BaseMetric):
Examples:
>>> import torch
>>> from mmcls.evaluation import AveragePrecision
>>> from mmpretrain.evaluation import AveragePrecision
>>> # --------- The Basic Usage for one-hot pred scores ---------
>>> y_pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2],
... [0.1, 0.2, 0.2, 0.1],
@ -479,7 +479,7 @@ class AveragePrecision(BaseMetric):
>>> AveragePrecision.calculate(y_pred, y_true)
tensor(70.833)
>>> # ------------------- Use with Evalutor -------------------
>>> from mmcls.structures import ClsDataSample
>>> from mmpretrain.structures import ClsDataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... ClsDataSample().set_pred_score(i).set_gt_score(j)

View File

@ -3,7 +3,7 @@ from typing import Dict, Sequence
from mmengine.evaluator import BaseMetric
from mmcls.registry import METRICS
from mmpretrain.registry import METRICS
@METRICS.register_module()
@ -14,7 +14,7 @@ class MultiTasksMetric(BaseMetric):
and the values is a list of the metric corresponds to this task
Examples:
>>> import torch
>>> from mmcls.evaluation import MultiTasksMetric
>>> from mmpretrain.evaluation import MultiTasksMetric
# -------------------- The Basic Usage --------------------
>>>task_metrics = {
'task0': [dict(type='Accuracy', topk=(1, ))],

View File

@ -8,7 +8,7 @@ from mmengine.evaluator import BaseMetric
from mmengine.structures import LabelData
from mmengine.utils import is_seq_of
from mmcls.registry import METRICS
from mmpretrain.registry import METRICS
from .single_label import to_tensor
@ -33,7 +33,7 @@ class RetrievalRecall(BaseMetric):
Use in the code:
>>> import torch
>>> from mmcls.evaluation import RetrievalRecall
>>> from mmpretrain.evaluation import RetrievalRecall
>>> # -------------------- The Basic Usage --------------------
>>> y_pred = [[0], [1], [2], [3]]
>>> y_true = [[0, 1], [2], [1], [0, 3]]
@ -48,7 +48,7 @@ class RetrievalRecall(BaseMetric):
[tensor(9.3000), tensor(48.4000)]
>>>
>>> # ------------------- Use with Evalutor -------------------
>>> from mmcls.structures import ClsDataSample
>>> from mmpretrain.structures import ClsDataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... ClsDataSample().set_gt_label([0, 1]).set_pred_score(

View File

@ -8,7 +8,7 @@ import torch
import torch.nn.functional as F
from mmengine.evaluator import BaseMetric
from mmcls.registry import METRICS
from mmpretrain.registry import METRICS
def to_tensor(value):
@ -91,7 +91,7 @@ class Accuracy(BaseMetric):
Examples:
>>> import torch
>>> from mmcls.evaluation import Accuracy
>>> from mmpretrain.evaluation import Accuracy
>>> # -------------------- The Basic Usage --------------------
>>> y_pred = [0, 2, 1, 3]
>>> y_true = [0, 1, 2, 3]
@ -104,7 +104,7 @@ class Accuracy(BaseMetric):
[[tensor([9.9000])], [tensor([51.5000])]]
>>>
>>> # ------------------- Use with Evalutor -------------------
>>> from mmcls.structures import ClsDataSample
>>> from mmpretrain.structures import ClsDataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... ClsDataSample().set_gt_label(0).set_pred_score(torch.rand(10))
@ -343,7 +343,7 @@ class SingleLabelMetric(BaseMetric):
Examples:
>>> import torch
>>> from mmcls.evaluation import SingleLabelMetric
>>> from mmpretrain.evaluation import SingleLabelMetric
>>> # -------------------- The Basic Usage --------------------
>>> y_pred = [0, 1, 1, 3]
>>> y_true = [0, 2, 1, 3]
@ -358,7 +358,7 @@ class SingleLabelMetric(BaseMetric):
(tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
>>>
>>> # ------------------- Use with Evalutor -------------------
>>> from mmcls.structures import ClsDataSample
>>> from mmpretrain.structures import ClsDataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... ClsDataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
@ -606,7 +606,7 @@ class ConfusionMatrix(BaseMetric):
1. The basic usage.
>>> import torch
>>> from mmcls.evaluation import ConfusionMatrix
>>> from mmpretrain.evaluation import ConfusionMatrix
>>> y_pred = [0, 1, 1, 3]
>>> y_true = [0, 2, 1, 3]
>>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)

View File

@ -3,7 +3,7 @@ from typing import Optional, Sequence
from mmengine.structures import LabelData
from mmcls.registry import METRICS
from mmpretrain.registry import METRICS
from .multi_label import AveragePrecision, MultiLabelMetric

View File

@ -8,6 +8,8 @@ from .heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .retrievers import * # noqa: F401,F403
from .selfsup import * # noqa: F401,F403
from .target_generators import * # noqa: F401,F403
from .tta import * # noqa: F401,F403
from .utils import * # noqa: F401,F403

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -9,7 +9,7 @@ from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils import (BEiTAttention, resize_pos_embed,
resize_relative_position_bias_table, to_2tuple)
from .vision_transformer import TransformerEncoderLayer, VisionTransformer

View File

@ -10,7 +10,7 @@ from mmcv.cnn.bricks.transformer import AdaptivePadding
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
from .vision_transformer import TransformerEncoderLayer

View File

@ -7,7 +7,7 @@ from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
build_norm_layer)
from mmengine.utils import digit_version
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -9,7 +9,7 @@ import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, ModuleList, Sequential
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils import GRN, build_norm_layer
from .base_backbone import BaseBackbone

View File

@ -9,7 +9,7 @@ from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, Sequential
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils import to_ntuple
from .resnet import Bottleneck as ResNetBottleneck
from .resnext import Bottleneck as ResNeXtBottleneck
@ -278,8 +278,8 @@ class CSPNet(BaseModule):
>>> from functools import partial
>>> import torch
>>> import torch.nn as nn
>>> from mmcls.models import CSPNet
>>> from mmcls.models.backbones.resnet import Bottleneck
>>> from mmpretrain.models import CSPNet
>>> from mmpretrain.models.backbones.resnet import Bottleneck
>>>
>>> # A simple example to build CSPNet.
>>> arch = dict(
@ -427,7 +427,7 @@ class CSPDarkNet(CSPNet):
Default: None.
Example:
>>> from mmcls.models import CSPDarkNet
>>> from mmpretrain.models import CSPDarkNet
>>> import torch
>>> model = CSPDarkNet(depth=53, out_indices=(0, 1, 2, 3, 4))
>>> model.eval()
@ -523,7 +523,7 @@ class CSPResNet(CSPNet):
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Example:
>>> from mmcls.models import CSPResNet
>>> from mmpretrain.models import CSPResNet
>>> import torch
>>> model = CSPResNet(depth=50, out_indices=(0, 1, 2, 3))
>>> model.eval()
@ -645,7 +645,7 @@ class CSPResNeXt(CSPResNet):
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Example:
>>> from mmcls.models import CSPResNeXt
>>> from mmpretrain.models import CSPResNeXt
>>> import torch
>>> model = CSPResNeXt(depth=50, out_indices=(0, 1, 2, 3))
>>> model.eval()

View File

@ -12,8 +12,8 @@ from mmengine.model import BaseModule, ModuleList
from mmengine.utils import to_2tuple
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.registry import MODELS
from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
from ..utils import ShiftWindowMSA

View File

@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .vision_transformer import VisionTransformer

View File

@ -10,7 +10,7 @@ from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.utils import deprecated_api_warning
from torch import nn
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils import LayerScale, MultiheadAttention, resize_pos_embed, to_2tuple
from .vision_transformer import VisionTransformer

View File

@ -10,7 +10,7 @@ import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import build_activation_layer, build_norm_layer
from torch.jit.annotations import List
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -8,7 +8,7 @@ import torch.nn as nn
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, ModuleList, Sequential
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils import (ChannelMultiheadAttention, PositionEncodingFourier,
build_norm_layer)
from .base_backbone import BaseBackbone

View File

@ -8,7 +8,7 @@ from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
build_norm_layer)
from mmengine.model import BaseModule, ModuleList, Sequential
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils import LayerScale
from .base_backbone import BaseBackbone
from .poolformer import Pooling
@ -376,7 +376,7 @@ class EfficientFormer(BaseBackbone):
Defaults to None.
Example:
>>> from mmcls.models import EfficientFormer
>>> from mmpretrain.models import EfficientFormer
>>> import torch
>>> inputs = torch.rand((1, 3, 224, 224))
>>> # build EfficientFormer backbone for classification task

View File

@ -9,9 +9,9 @@ import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import ConvModule, DropPath
from mmengine.model import BaseModule, Sequential
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.utils import InvertedResidual, SELayer, make_divisible
from mmcls.registry import MODELS
from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.models.utils import InvertedResidual, SELayer, make_divisible
from mmpretrain.registry import MODELS
class EdgeResidual(BaseModule):

View File

@ -7,10 +7,10 @@ from mmcv.cnn.bricks import ConvModule, DropPath
from mmengine.model import Sequential
from torch import Tensor
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.backbones.efficientnet import EdgeResidual as FusedMBConv
from mmcls.models.utils import InvertedResidual as MBConv
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils import InvertedResidual as MBConv
from .base_backbone import BaseBackbone
from .efficientnet import EdgeResidual as FusedMBConv
class EnhancedConvModule(ConvModule):

View File

@ -16,8 +16,8 @@ import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from mmcv.cnn.bricks import DropPath
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.registry import MODELS
from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
from ..utils import LayerScale

View File

@ -4,7 +4,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .resnet import BasicBlock, Bottleneck, ResLayer, get_expansion
@ -232,7 +232,7 @@ class HRNet(BaseModule):
Example:
>>> import torch
>>> from mmcls.models import HRNet
>>> from mmpretrain.models import HRNet
>>> extra = dict(
>>> stage1=dict(
>>> num_modules=1,

View File

@ -6,7 +6,7 @@ import torch.nn as nn
from mmcv.cnn import build_conv_layer
from mmengine.model import BaseModule
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
@ -389,7 +389,7 @@ class InceptionV3(BaseBackbone):
Example:
>>> import torch
>>> from mmcls.models import build_backbone
>>> from mmpretrain.models import build_backbone
>>>
>>> inputs = torch.rand(2, 3, 299, 299)
>>> cfg = dict(type='InceptionV3', num_classes=100)

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -7,8 +7,8 @@ from mmcv.cnn import build_activation_layer, fuse_conv_bn
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, ModuleList, Sequential
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.registry import MODELS
from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer

View File

@ -9,11 +9,10 @@ from mmengine.model import BaseModule
from torch import nn
from torch.utils.checkpoint import checkpoint
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
from mmcls.models.utils.attention import WindowMSA
from mmcls.models.utils.helpers import to_2tuple
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils import WindowMSA, to_2tuple
from .base_backbone import BaseBackbone
from .vision_transformer import TransformerEncoderLayer
class MixMIMWindowAttention(WindowMSA):

View File

@ -6,7 +6,7 @@ from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils import to_2tuple
from .base_backbone import BaseBackbone

View File

@ -5,8 +5,8 @@ from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils import make_divisible
from mmcls.registry import MODELS
from mmpretrain.models.utils import make_divisible
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -2,7 +2,7 @@
from mmcv.cnn import ConvModule
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils import InvertedResidual
from .base_backbone import BaseBackbone

View File

@ -9,7 +9,7 @@ from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils.se_layer import SELayer
from .base_backbone import BaseBackbone
@ -302,7 +302,7 @@ class MobileOne(BaseBackbone):
init_cfg (dict or list[dict], optional): Initialization config dict.
Example:
>>> from mmcls.models import MobileOne
>>> from mmpretrain.models import MobileOne
>>> import torch
>>> x = torch.rand(1, 3, 224, 224)
>>> model = MobileOne("s0", out_indices=(0, 1, 2, 3))

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from torch import nn
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
from .mobilenet_v2 import InvertedResidual
from .vision_transformer import TransformerEncoderLayer

View File

@ -489,7 +489,7 @@ class MViT(BaseBackbone):
Examples:
>>> import torch
>>> from mmcls.models import build_backbone
>>> from mmpretrain.models import build_backbone
>>>
>>> cfg = dict(type='MViT', arch='tiny', out_scales=[0, 1, 2, 3])
>>> model = build_backbone(cfg)

View File

@ -6,7 +6,7 @@ import torch.nn as nn
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -3,7 +3,7 @@ import numpy as np
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .resnet import ResNet
from .resnext import Bottleneck
@ -43,7 +43,7 @@ class RegNet(ResNet):
in resblocks to let them behave as identity. Default: True.
Example:
>>> from mmcls.models import RegNet
>>> from mmpretrain.models import RegNet
>>> import torch
>>> self = RegNet(
arch=dict(

View File

@ -7,7 +7,7 @@ from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -8,8 +8,8 @@ from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
from mmcv.cnn.bricks.transformer import PatchEmbed as _PatchEmbed
from mmengine.model import BaseModule, ModuleList, Sequential
from mmcls.models.utils import SELayer, to_2tuple
from mmcls.registry import MODELS
from mmpretrain.models.utils import SELayer, to_2tuple
from mmpretrain.registry import MODELS
def fuse_bn(conv_or_fc, bn):
@ -88,7 +88,7 @@ class PatchEmbed(_PatchEmbed):
class GlobalPerceptron(SELayer):
"""GlobalPerceptron implemented by using ``mmcls.modes.SELayer``.
"""GlobalPerceptron implemented by using ``mmpretrain.modes.SELayer``.
Args:
input_channels (int): The number of input (and output) channels

View File

@ -8,7 +8,7 @@ from mmengine.model import BaseModule, Sequential
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from torch import nn
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from ..utils.se_layer import SELayer
from .base_backbone import BaseBackbone

View File

@ -7,7 +7,7 @@ import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import ModuleList, Sequential
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
@ -258,7 +258,7 @@ class Res2Net(ResNet):
Defaults to None.
Example:
>>> from mmcls.models import Res2Net
>>> from mmpretrain.models import Res2Net
>>> import torch
>>> model = Res2Net(depth=50,
... scales=4,

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResLayer, ResNetV1d

View File

@ -9,7 +9,7 @@ from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
eps = 1.0e-5
@ -438,7 +438,7 @@ class ResNet(BaseBackbone):
in resblocks to let them behave as identity. Default: True.
Example:
>>> from mmcls.models import ResNet
>>> from mmpretrain.models import ResNet
>>> import torch
>>> self = ResNet(depth=18)
>>> self.eval()

View File

@ -2,7 +2,7 @@
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .resnet import ResNet

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcls.registry import MODELS
from mmpretrain.registry import MODELS
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResLayer, ResNet

Some files were not shown because too many files have changed in this diff Show More