[Improve] Update registries of mmcls. (#1306)

* [Improve] Update registries of mmcls.

* Update according to comments
Ma Zerun 2023-01-11 15:20:51 +08:00 committed by GitHub
parent aa53f7790c
commit 97c4ae8805
34 changed files with 191 additions and 190 deletions

View File

@ -202,7 +202,7 @@ workflows:
name: minimum_version_cpu
torch: 1.6.0
torchvision: 0.7.0
python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images
python: 3.7.16
requires:
- lint
- build_cpu_with_3rdparty:

View File

@ -5,7 +5,6 @@ from mmengine.fileio import dump
from rich import print_json
from mmcls.apis import inference_model, init_model
from mmcls.utils import register_all_modules
def main():
@ -17,8 +16,6 @@ def main():
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
# register all modules and set mmcls as the default scope.
register_all_modules()
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
# test a single image

View File

@ -2,7 +2,7 @@
In this section we demonstrate how to prepare an environment with PyTorch.
MMClassification works on Linux, Windows and macOS. It requires Python 3.6+, CUDA 9.2+ and PyTorch 1.6+.
MMClassification works on Linux, Windows and macOS. It requires Python 3.7+, CUDA 9.2+ and PyTorch 1.6+.
```{note}
If you are experienced with PyTorch and have already installed it, just skip this part and jump to the [next section](#installation). Otherwise, you can follow these steps for the preparation.
@ -93,13 +93,9 @@ You will see the output result dict including `pred_label`, `pred_score` and `pr
Option (b). If you install mmcls as a python package, open your python interpreter and copy and paste the following code.
```python
from mmcls.apis import init_model, inference_model
from mmcls.utils import register_all_modules
from mmcls import get_model, inference_model
config_file = 'resnet50_8xb32_in1k.py'
checkpoint_file = 'resnet50_8xb32_in1k_20210831-ea4938fc.pth'
register_all_modules() # register all modules and set mmcls as the default scope.
model = init_model(config_file, checkpoint_file, device='cpu') # or device='cuda:0'
model = get_model('resnet18_8xb32_in1k', device='cpu') # or device='cuda:0'
inference_model(model, 'demo/demo.JPEG')
```
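A hedged aside: per the inference tutorial change further down in this diff, `get_model` also accepts `pretrained=True` to download and load the matching checkpoint, so the snippet above can produce meaningful predictions, e.g.:

```python
from mmcls import get_model, inference_model

# pretrained=True loads the checkpoint registered for this model name
# (shown in the tutorial diff below); device can be 'cuda:0'.
model = get_model('resnet18_8xb32_in1k', pretrained=True, device='cpu')
print(inference_model(model, 'demo/demo.JPEG')['pred_class'])
```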

View File

@ -9,31 +9,23 @@ As for how to test existing models on standard datasets, please see this [guide]
MMClassification provides high-level Python APIs for inference on a given image:
- [init_model](mmcls.apis.init_model): Initialize a model with a config and checkpoint
- [inference_model](mmcls.apis.inference_model): Inference on a given image
- [`get_model`](mmcls.apis.get_model): Get a model with the model name.
- [`init_model`](mmcls.apis.init_model): Initialize a model with a config and checkpoint
- [`inference_model`](mmcls.apis.inference_model): Inference on a given image
Here is an example of building a model and running inference on a given image using an ImageNet-1k pre-trained checkpoint.
```{note}
If you use mmcls as a 3rd-party package, you need to download the config and the demo image in the example.
Run 'mim download mmcls --config resnet50_8xb32_in1k --dest .' to download the required config.
Run 'wget https://github.com/open-mmlab/mmclassification/blob/master/demo/demo.JPEG' to download the desired demo image.
You can use `wget https://github.com/open-mmlab/mmclassification/raw/master/demo/demo.JPEG` to download the example image or use your own image.
```
```python
from mmcls.apis import inference_model, init_model
from mmcls.utils import register_all_modules
from mmcls import get_model, inference_model
config_path = './configs/resnet/resnet50_8xb32_in1k.py'
checkpoint_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # can be a local path
img_path = 'demo/demo.JPEG' # you can specify your own picture path
img_path = 'demo.JPEG' # you can specify your own picture path
# register all modules and set mmcls as the default scope.
register_all_modules()
# build the model from a config file and a checkpoint file
model = init_model(config_path, checkpoint_path, device="cpu") # device can be 'cuda:0'
model = get_model('resnet50_8xb32_in1k', pretrained=True, device="cpu") # device can be 'cuda:0'
# test a single image
result = inference_model(model, img_path)
```
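The returned `result` is a plain Python dict. A minimal sketch of inspecting the top-1 prediction, using the keys the surrounding docs list (`pred_label`, `pred_score`, `pred_class`):

```python
# `result` comes from inference_model(model, img_path) above.
# Keys per the docs: pred_label (class index), pred_score (top-1 score),
# pred_class (class name).
print(result['pred_class'], float(result['pred_score']))
```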

View File

@ -2,7 +2,7 @@
In this section, we demonstrate how to prepare an environment with PyTorch.
MMClassification works on Linux, Windows and macOS. It requires Python 3.6+, CUDA 9.2+ and PyTorch 1.6+.
MMClassification works on Linux, Windows and macOS. It requires Python 3.7+, CUDA 9.2+ and PyTorch 1.6+.
```{note}
If you are already familiar with configuring a PyTorch environment and have finished it, skip this part and jump to the [next section](#安装).
@ -97,13 +97,9 @@ python demo/image_demo.py demo/demo.JPEG resnet50_8xb32_in1k.py resnet50_8xb32_i
If you installed mmcls **as a Python package**, open your Python interpreter and paste the following code:
```python
from mmcls.apis import init_model, inference_model
from mmcls.utils import register_all_modules
from mmcls import get_model, inference_model
config_file = 'resnet50_8xb32_in1k.py'
checkpoint_file = 'resnet50_8xb32_in1k_20210831-ea4938fc.pth'
register_all_modules() # register all modules and set mmcls as the default scope.
model = init_model(config_file, checkpoint_file, device='cpu') # or device='cuda:0'
model = get_model('resnet18_8xb32_in1k', device='cpu') # or device='cuda:0'
inference_model(model, 'demo/demo.JPEG')
```

View File

@ -9,34 +9,25 @@ MMClassification provides in the [Model Zoo](../modelzoo_statistics.md)
MMClassification provides high-level Python APIs for image inference:
- [init_model](mmcls.apis.init_model): Initialize a model.
- [inference_model](mmcls.apis.inference_model): Inference on a given image.
- [`get_model`](mmcls.apis.get_model): Get a model by its name.
- [`init_model`](mmcls.apis.init_model): Initialize a model with a config file and a checkpoint file.
- [`inference_model`](mmcls.apis.inference_model): Inference on a given image.
Here is an example of initializing a model with an ImageNet-1k pre-trained checkpoint and running inference on a given image.
```{note}
If you use mmcls as a 3rd-party package, you need to download the config file and the example image used in this example.
Run 'mim download mmcls --config resnet50_8xb32_in1k --dest .' to download the required config.
Run 'wget https://github.com/open-mmlab/mmclassification/blob/master/demo/demo.JPEG' to download the required image.
You can run `wget https://github.com/open-mmlab/mmclassification/raw/master/demo/demo.JPEG` to download the example image, or use your own.
```
```python
from mmcls.apis import inference_model, init_model
from mmcls.utils import register_all_modules
from mmcls import get_model, inference_model
config_path = './configs/resnet/resnet50_8xb32_in1k.py'
checkpoint_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # can also be a local path
img_path = 'demo/demo.JPEG' # you can specify your own image path
img_path = 'demo.JPEG' # you can specify your own image path
# Register modules
register_all_modules() # register all modules into the default mmcls scope
# Build the model
model = init_model(config_path, checkpoint_path, device="cpu") # `device` can be 'cuda:0'
model = get_model('resnet50_8xb32_in1k', pretrained=True, device="cpu") # `device` can be 'cuda:0'
# Run inference
result = inference_model(model, img_path)
print(result)
```
`result` is a dict containing `pred_label`, `pred_score`, `pred_scores` and `pred_class`; the output looks like:

View File

@ -22,8 +22,6 @@ def inference_model(model: 'BaseModel', img: Union[str, np.ndarray]):
from mmengine.dataset import Compose, default_collate
from mmengine.registry import DefaultScope
import mmcls.datasets # noqa: F401
cfg = model.cfg
# build the data pipeline
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline

View File

@ -111,7 +111,6 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
config.model.setdefault('data_preprocessor',
config.get('data_preprocessor', None))
import mmcls.models # noqa: F401
from mmcls.registry import MODELS
model = MODELS.build(config.model)

View File

@ -7,7 +7,7 @@ import mmengine
import numpy as np
from mmengine.dataset import BaseDataset as _BaseDataset
from .builder import DATASETS
from mmcls.registry import DATASETS, TRANSFORMS
def expanduser(path):
@ -89,6 +89,13 @@ class BaseDataset(_BaseDataset):
ann_file = expanduser(ann_file)
metainfo = self._compat_classes(metainfo, classes)
transforms = []
for transform in pipeline:
if isinstance(transform, dict):
transforms.append(TRANSFORMS.build(transform))
else:
transforms.append(transform)
super().__init__(
ann_file=ann_file,
metainfo=metainfo,
@ -97,7 +104,7 @@ class BaseDataset(_BaseDataset):
filter_cfg=filter_cfg,
indices=indices,
serialize_data=serialize_data,
pipeline=pipeline,
pipeline=transforms,
test_mode=test_mode,
lazy_init=lazy_init,
max_refetch=max_refetch)
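For context, a minimal sketch of what the added loop does with a dict entry (the `ResizeEdge` transform is used purely as an illustration, assuming mmcls is installed):

```python
from mmcls.registry import TRANSFORMS

# A config dict is instantiated through the registry; already-built
# callables in the pipeline are passed through unchanged.
resize = TRANSFORMS.build(dict(type='ResizeEdge', scale=256, edge='short'))
```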

View File

@ -65,8 +65,9 @@ class AutoAugment(RandomChoice):
self.hparams = hparams
self.policies = [[merge_hparams(t, hparams) for t in sub]
for sub in policies]
transforms = [[TRANSFORMS.build(t) for t in sub] for sub in policies]
super().__init__(transforms=self.policies)
super().__init__(transforms=transforms)
def __repr__(self) -> str:
policies_str = ''

View File

@ -67,8 +67,7 @@ class TIMMBackbone(BaseBackbone):
import timm
except ImportError:
raise ImportError(
'Failed to import timm. Please run "pip install timm". '
'"pip install dataclasses" may also be needed for Python 3.6.')
'Failed to import timm. Please run "pip install timm".')
if not isinstance(pretrained, bool):
raise TypeError('pretrained must be bool, not str for model path')

View File

@ -32,66 +32,159 @@ from mmengine.registry import \
from mmengine.registry import Registry
__all__ = [
'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', 'HOOKS', 'DATASETS',
'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'MODEL_WRAPPERS',
'WEIGHT_INITIALIZERS', 'BATCH_AUGMENTS', 'OPTIMIZERS', 'OPTIM_WRAPPERS',
'OPTIM_WRAPPER_CONSTRUCTORS', 'PARAM_SCHEDULERS', 'METRICS', 'TASK_UTILS',
'VISUALIZERS', 'VISBACKENDS', 'EVALUATORS', 'LOG_PROCESSORS'
'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', 'HOOKS', 'LOG_PROCESSORS',
'OPTIMIZERS', 'OPTIM_WRAPPERS', 'OPTIM_WRAPPER_CONSTRUCTORS',
'PARAM_SCHEDULERS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS',
'MODEL_WRAPPERS', 'WEIGHT_INITIALIZERS', 'BATCH_AUGMENTS', 'TASK_UTILS',
'METRICS', 'EVALUATORS', 'VISUALIZERS', 'VISBACKENDS'
]
# Registries For Runner and the related
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
# manage runner constructors that define how to initialize runners
#######################################################################
# mmcls.engine #
#######################################################################
# Runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry(
'runner',
parent=MMENGINE_RUNNERS,
locations=['mmcls.engine'],
)
# Runner constructors that define how to initialize runners
RUNNER_CONSTRUCTORS = Registry(
'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
# manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
# manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
# Registries For Data and the related
# manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
# manage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
# manage all kinds of batch augmentations like Mixup and CutMix.
BATCH_AUGMENTS = Registry('batch augment')
# Registries For Optimizer and the related
# manage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optimizer_wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
# manage constructors that customize the optimization hyperparameters.
'runner constructor',
parent=MMENGINE_RUNNER_CONSTRUCTORS,
locations=['mmcls.engine'],
)
# Loops which define the training or test process, like `EpochBasedTrainLoop`
LOOPS = Registry(
'loop',
parent=MMENGINE_LOOPS,
locations=['mmcls.engine'],
)
# Hooks to add additional functions during running, like `CheckpointHook`
HOOKS = Registry(
'hook',
parent=MMENGINE_HOOKS,
locations=['mmcls.engine'],
)
# Log processors to process the scalar log data.
LOG_PROCESSORS = Registry(
'log processor',
parent=MMENGINE_LOG_PROCESSORS,
locations=['mmcls.engine'],
)
# Optimizers to optimize the model weights, like `SGD` and `Adam`.
OPTIMIZERS = Registry(
'optimizer',
parent=MMENGINE_OPTIMIZERS,
locations=['mmcls.engine'],
)
# Optimizer wrappers to enhance the optimization process.
OPTIM_WRAPPERS = Registry(
'optimizer_wrapper',
parent=MMENGINE_OPTIM_WRAPPERS,
locations=['mmcls.engine'],
)
# Optimizer constructors to customize the hyperparameters of optimizers.
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
'optimizer wrapper constructor',
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
# manage all kinds of parameter schedulers like `MultiStepLR`
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
locations=['mmcls.engine'],
)
# Parameter schedulers to dynamically adjust optimization parameters.
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
'parameter scheduler',
parent=MMENGINE_PARAM_SCHEDULERS,
locations=['mmcls.engine'],
)
# manage all kinds of metrics
METRICS = Registry('metric', parent=MMENGINE_METRICS)
# manage all kinds of evaluators
EVALUATORS = Registry('evaluator', parent=MMENGINE_EVALUATOR)
#######################################################################
# mmcls.datasets #
#######################################################################
# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
# Datasets like `ImageNet` and `CIFAR10`.
DATASETS = Registry(
'dataset',
parent=MMENGINE_DATASETS,
locations=['mmcls.datasets'],
)
# Samplers to sample the dataset.
DATA_SAMPLERS = Registry(
'data sampler',
parent=MMENGINE_DATA_SAMPLERS,
locations=['mmcls.datasets'],
)
# Transforms to process the samples from the dataset.
TRANSFORMS = Registry(
'transform',
parent=MMENGINE_TRANSFORMS,
locations=['mmcls.datasets'],
)
# Registries For Visualizer and the related
# manage visualizer
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
# manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
#######################################################################
# mmcls.models #
#######################################################################
# manage all kinds log processors
LOG_PROCESSORS = Registry('log processor', parent=MMENGINE_LOG_PROCESSORS)
# Neural network modules inheriting `nn.Module`.
MODELS = Registry(
'model',
parent=MMENGINE_MODELS,
locations=['mmcls.models'],
)
# Model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry(
'model_wrapper',
parent=MMENGINE_MODEL_WRAPPERS,
locations=['mmcls.models'],
)
# Weight initialization methods like uniform, xavier.
WEIGHT_INITIALIZERS = Registry(
'weight initializer',
parent=MMENGINE_WEIGHT_INITIALIZERS,
locations=['mmcls.models'],
)
# Batch augmentations like `Mixup` and `CutMix`.
BATCH_AUGMENTS = Registry(
'batch augment',
locations=['mmcls.models'],
)
# Task-specific modules like anchor generators and box coders
TASK_UTILS = Registry(
'task util',
parent=MMENGINE_TASK_UTILS,
locations=['mmcls.models'],
)
#######################################################################
# mmcls.evaluation #
#######################################################################
# Metrics to evaluate the model prediction results.
METRICS = Registry(
'metric',
parent=MMENGINE_METRICS,
locations=['mmcls.evaluation'],
)
# Evaluators to define the evaluation process.
EVALUATORS = Registry(
'evaluator',
parent=MMENGINE_EVALUATOR,
locations=['mmcls.evaluation'],
)
#######################################################################
# mmcls.visualization #
#######################################################################
# Visualizers to display task-specific results.
VISUALIZERS = Registry(
'visualizer',
parent=MMENGINE_VISUALIZERS,
locations=['mmcls.visualization'],
)
# Backends to save the visualization results, like TensorBoard, WandB.
VISBACKENDS = Registry(
'vis_backend',
parent=MMENGINE_VISBACKENDS,
locations=['mmcls.visualization'],
)
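A hedged usage sketch of what the new `locations` field buys: building an mmcls module from a config no longer requires a prior `register_all_modules()` call, because the registry lazily imports the listed packages on the first `build()`. The config below is a plain ResNet-18 classifier, shown only for illustration:

```python
from mmcls.registry import MODELS

cfg = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=18),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ),
)
model = MODELS.build(cfg)  # triggers a lazy import of 'mmcls.models'
```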

View File

@ -169,12 +169,12 @@ if __name__ == '__main__':
keywords='computer vision, image classification',
packages=find_packages(exclude=('configs', 'tools', 'demo')),
include_package_data=True,
python_requires='>=3.7',
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',

View File

@ -11,9 +11,7 @@ import numpy as np
from mmengine.logging import MMLogger
from mmcls.registry import DATASETS, TRANSFORMS
from mmcls.utils import register_all_modules
register_all_modules()
ASSETS_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '../data/dataset'))

View File

@ -7,9 +7,6 @@ from unittest.mock import ANY, patch
import numpy as np
from mmcls.registry import TRANSFORMS
from mmcls.utils import register_all_modules
register_all_modules()
def construct_toy_data():

View File

@ -11,9 +11,6 @@ from PIL import Image
from mmcls.registry import TRANSFORMS
from mmcls.structures import ClsDataSample, MultiTaskDataSample
from mmcls.utils import register_all_modules
register_all_modules()
class TestPackClsInputs(unittest.TestCase):

View File

@ -9,9 +9,6 @@ import mmengine
import numpy as np
from mmcls.registry import TRANSFORMS
from mmcls.utils import register_all_modules
register_all_modules()
def construct_toy_data():

View File

@ -15,9 +15,6 @@ from torch.utils.data import DataLoader, Dataset
from mmcls.registry import HOOKS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules
register_all_modules()
class ExampleDataset(Dataset):

View File

@ -20,9 +20,6 @@ from mmcls.models import CrossEntropyLoss
from mmcls.models.heads.cls_head import ClsHead
from mmcls.models.losses import LabelSmoothLoss
from mmcls.models.utils.batch_augments import RandomBatchAugment
from mmcls.utils import register_all_modules
register_all_modules()
class SimpleDataPreprocessor(BaseDataPreprocessor):

View File

@ -10,11 +10,8 @@ from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop
from mmcls.engine import VisualizationHook
from mmcls.registry import HOOKS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules
from mmcls.visualization import ClsVisualizer
register_all_modules()
class TestVisualizationHook(TestCase):

View File

@ -8,9 +8,6 @@ from mmengine.evaluator import Evaluator
from mmcls.evaluation.metrics import AveragePrecision, MultiLabelMetric
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules
register_all_modules()
class TestMultiLabel(TestCase):

View File

@ -7,9 +7,6 @@ import torch
from mmengine.evaluator import Evaluator
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules
register_all_modules()
class TestVOCMultiLabel(TestCase):

View File

@ -10,9 +10,6 @@ from mmengine import ConfigDict
from mmcls.models import ImageClassifier
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules
register_all_modules()
def has_timm() -> bool:

View File

@ -11,9 +11,6 @@ from mmengine import is_seq_of
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample, MultiTaskDataSample
from mmcls.utils import register_all_modules
register_all_modules()
def setup_seed(seed):

View File

@ -14,9 +14,6 @@ from torch.utils.data import DataLoader, Dataset
from mmcls.datasets.transforms import PackClsInputs
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules
register_all_modules()
class ExampleDataset(Dataset):

View File

@ -8,9 +8,6 @@ from mmengine import ConfigDict
from mmcls.models import AverageClsScoreTTA, ImageClassifier
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules
register_all_modules()
class TestAverageClsScoreTTA(TestCase):

View File

@ -6,9 +6,6 @@ import torch
from mmcls.models import ClsDataPreprocessor, RandomBatchAugment
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules
register_all_modules()
class TestClsDataPreprocessor(TestCase):

View File

@ -1,15 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import itertools
import json
from unittest.mock import MagicMock
import mmengine
import rich
from mmengine import Config, DictAction
from mmengine.evaluator import Evaluator
from mmcls.utils import register_all_modules
from mmengine.registry import init_default_scope
def parse_args():
@ -34,20 +30,18 @@ def parse_args():
def main():
args = parse_args()
register_all_modules()
# load config
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
init_default_scope('mmcls') # Use mmcls as default scope.
predictions = mmengine.load(args.pkl_results)
evaluator = Evaluator(cfg.test_evaluator)
# dataset is not needed, use an endless iterator to mock it.
fake_dataset = itertools.repeat({'data_sample': MagicMock()})
eval_results = evaluator.offline_evaluate(fake_dataset, predictions)
rich.print_json(json.dumps(eval_results))
eval_results = evaluator.offline_evaluate(predictions, None)
rich.print(eval_results)
if __name__ == '__main__':
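For reference, the pattern replacing `register_all_modules()` in the tools above and below is a one-liner from mmengine; a minimal sketch:

```python
from mmengine.registry import init_default_scope

# Set mmcls as the default scope so that unscoped registry keys in the
# loaded config resolve to the mmcls registries.
init_default_scope('mmcls')
```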

View File

@ -12,8 +12,6 @@ from mmengine.runner import Runner, find_latest_checkpoint
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmcls.utils import register_all_modules
EXP_INFO_FILE = 'kfold_exp.json'
prog_description = """K-Fold cross-validation.
@ -222,10 +220,6 @@ def train_single_fold(cfg, num_splits, fold, resume_ckpt=None):
def main():
args = parse_args()
# register all modules in mmcls into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)

View File

@ -9,8 +9,6 @@ from mmengine.config import Config, ConfigDict, DictAction
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmcls.utils import register_all_modules
def parse_args():
parser = argparse.ArgumentParser(
@ -158,12 +156,10 @@ def merge_args(cfg, args):
def main():
args = parse_args()
# register all modules in mmcls into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)
# merge cli arguments to config
cfg = merge_args(cfg, args)
# build the runner from config

View File

@ -9,8 +9,6 @@ from mmengine.runner import Runner
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmcls.utils import register_all_modules
def parse_args():
parser = argparse.ArgumentParser(description='Train a classifier')
@ -141,10 +139,6 @@ def merge_args(cfg, args):
def main():
args = parse_args()
# register all modules in mmcls into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)

View File

@ -8,11 +8,11 @@ import mmcv
import numpy as np
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from mmengine.visualization import Visualizer
from mmcls.datasets.builder import build_dataset
from mmcls.utils import register_all_modules
from mmcls.visualization import ClsVisualizer
from mmcls.visualization.cls_visualizer import _get_adaptive_scale
@ -170,8 +170,7 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# register all modules in mmcls into the registries
register_all_modules()
init_default_scope('mmcls') # Use mmcls as default scope.
dataset_cfg = cfg.get(args.phase + '_dataloader').get('dataset')
dataset = build_dataset(dataset_cfg)

View File

@ -11,12 +11,12 @@ import numpy as np
from mmcv.transforms import Compose
from mmengine.config import Config, DictAction
from mmengine.dataset import default_collate
from mmengine.registry import init_default_scope
from mmengine.utils import to_2tuple
from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm
from mmcls import digit_version
from mmcls.apis import init_model
from mmcls.utils import register_all_modules
try:
from pytorch_grad_cam import (EigenCAM, EigenGradCAM, GradCAM,
@ -265,7 +265,7 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
register_all_modules()
init_default_scope('mmcls')
# build the model from a config file and a checkpoint file
model = init_model(cfg, args.checkpoint, device=args.device)
if args.preview_model:

View File

@ -16,8 +16,6 @@ from mmengine.runner import Runner
from mmengine.visualization import Visualizer
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn
from mmcls.utils import register_all_modules
class SimpleModel(BaseModel):
"""simple model that do nothing in train_step."""
@ -214,8 +212,6 @@ def main():
osp.splitext(osp.basename(args.config))[0])
cfg.log_level = args.log_level
# register all modules in mmcls into the registries
register_all_modules()
# make sure save_root exists
if args.save_path and not args.save_path.parent.exists():