[Improve] Update registries of mmcls. (#1306)
* [Improve] Update registries of mmcls. * Update according to commentspull/1317/merge
parent
aa53f7790c
commit
97c4ae8805
|
@ -202,7 +202,7 @@ workflows:
|
|||
name: minimum_version_cpu
|
||||
torch: 1.6.0
|
||||
torchvision: 0.7.0
|
||||
python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images
|
||||
python: 3.7.16
|
||||
requires:
|
||||
- lint
|
||||
- build_cpu_with_3rdparty:
|
||||
|
|
|
@ -5,7 +5,6 @@ from mmengine.fileio import dump
|
|||
from rich import print_json
|
||||
|
||||
from mmcls.apis import inference_model, init_model
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -17,8 +16,6 @@ def main():
|
|||
'--device', default='cuda:0', help='Device used for inference')
|
||||
args = parser.parse_args()
|
||||
|
||||
# register all modules and set mmcls as the default scope.
|
||||
register_all_modules()
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_model(args.config, args.checkpoint, device=args.device)
|
||||
# test a single image
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
In this section we demonstrate how to prepare an environment with PyTorch.
|
||||
|
||||
MMClassification works on Linux, Windows and macOS. It requires Python 3.6+, CUDA 9.2+ and PyTorch 1.6+.
|
||||
MMClassification works on Linux, Windows and macOS. It requires Python 3.7+, CUDA 9.2+ and PyTorch 1.6+.
|
||||
|
||||
```{note}
|
||||
If you are experienced with PyTorch and have already installed it, just skip this part and jump to the [next section](#installation). Otherwise, you can follow these steps for the preparation.
|
||||
|
@ -93,13 +93,9 @@ You will see the output result dict including `pred_label`, `pred_score` and `pr
|
|||
Option (b). If you install mmcls as a python package, open your python interpreter and copy&paste the following codes.
|
||||
|
||||
```python
|
||||
from mmcls.apis import init_model, inference_model
|
||||
from mmcls.utils import register_all_modules
|
||||
from mmcls import get_model, inference_model
|
||||
|
||||
config_file = 'resnet50_8xb32_in1k.py'
|
||||
checkpoint_file = 'resnet50_8xb32_in1k_20210831-ea4938fc.pth'
|
||||
register_all_modules() # register all modules and set mmcls as the default scope.
|
||||
model = init_model(config_file, checkpoint_file, device='cpu') # or device='cuda:0'
|
||||
model = get_model('resnet18_8xb32_in1k', device='cpu') # or device='cuda:0'
|
||||
inference_model(model, 'demo/demo.JPEG')
|
||||
```
|
||||
|
||||
|
|
|
@ -9,31 +9,23 @@ As for how to test existing models on standard datasets, please see this [guide]
|
|||
|
||||
MMClassification provides high-level Python APIs for inference on a given image:
|
||||
|
||||
- [init_model](mmcls.apis.init_model): Initialize a model with a config and checkpoint
|
||||
- [inference_model](mmcls.apis.inference_model): Inference on a given image
|
||||
- [`get_model`](mmcls.apis.get_model): Get a model with the model name.
|
||||
- [`init_model`](mmcls.apis.init_model): Initialize a model with a config and checkpoint
|
||||
- [`inference_model`](mmcls.apis.inference_model): Inference on a given image
|
||||
|
||||
Here is an example of building the model and inference on a given image by using ImageNet-1k pre-trained checkpoint.
|
||||
|
||||
```{note}
|
||||
If you use mmcls as a 3rd-party package, you need to download the conifg and the demo image in the example.
|
||||
|
||||
Run 'mim download mmcls --config resnet50_8xb32_in1k --dest .' to download the required config.
|
||||
|
||||
Run 'wget https://github.com/open-mmlab/mmclassification/blob/master/demo/demo.JPEG' to download the desired demo image.
|
||||
You can use `wget https://github.com/open-mmlab/mmclassification/raw/master/demo/demo.JPEG` to download the example image or use your own image.
|
||||
```
|
||||
|
||||
```python
|
||||
from mmcls.apis import inference_model, init_model
|
||||
from mmcls.utils import register_all_modules
|
||||
from mmcls import get_model, inference_model
|
||||
|
||||
config_path = './configs/resnet/resnet50_8xb32_in1k.py'
|
||||
checkpoint_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # can be a local path
|
||||
img_path = 'demo/demo.JPEG' # you can specify your own picture path
|
||||
img_path = 'demo.JPEG' # you can specify your own picture path
|
||||
|
||||
# register all modules and set mmcls as the default scope.
|
||||
register_all_modules()
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_model(config_path, checkpoint_path, device="cpu") # device can be 'cuda:0'
|
||||
model = get_model('resnet50_8xb32_in1k', pretrained=True, device="cpu") # device can be 'cuda:0'
|
||||
# test a single image
|
||||
result = inference_model(model, img_path)
|
||||
```
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
在本节中,我们将演示如何准备 PyTorch 相关的依赖环境。
|
||||
|
||||
MMClassification 适用于 Linux、Windows 和 macOS。它需要 Python 3.6+、CUDA 9.2+ 和 PyTorch 1.6+。
|
||||
MMClassification 适用于 Linux、Windows 和 macOS。它需要 Python 3.7+、CUDA 9.2+ 和 PyTorch 1.6+。
|
||||
|
||||
```{note}
|
||||
如果你对配置 PyTorch 环境已经很熟悉,并且已经完成了配置,可以直接进入[下一节](#安装)。
|
||||
|
@ -97,13 +97,9 @@ python demo/image_demo.py demo/demo.JPEG resnet50_8xb32_in1k.py resnet50_8xb32_i
|
|||
如果你是**作为 Python 包安装**,那么可以打开你的 Python 解释器,并粘贴如下代码:
|
||||
|
||||
```python
|
||||
from mmcls.apis import init_model, inference_model
|
||||
from mmcls.utils import register_all_modules
|
||||
from mmcls import get_model, inference_model
|
||||
|
||||
config_file = 'resnet50_8xb32_in1k.py'
|
||||
checkpoint_file = 'resnet50_8xb32_in1k_20210831-ea4938fc.pth'
|
||||
register_all_modules() # 注册所有模块,并将 mmcls 设为默认 scope。
|
||||
model = init_model(config_file, checkpoint_file, device='cpu') # 或者 device='cuda:0'
|
||||
model = get_model('resnet18_8xb32_in1k', device='cpu') # 或者 device='cuda:0'
|
||||
inference_model(model, 'demo/demo.JPEG')
|
||||
```
|
||||
|
||||
|
|
|
@ -9,34 +9,25 @@ MMClassification 在 [Model Zoo](../modelzoo_statistics.md) 中提供了用于
|
|||
|
||||
MMClassification 为图像推理提供高级 Python API:
|
||||
|
||||
- [init_model](mmcls.apis.init_model): 初始化一个模型。
|
||||
- [inference_model](mmcls.apis.inference_model):对给定图片进行推理。
|
||||
- [`get_model`](mmcls.apis.get_model): 根据名称获取一个模型。
|
||||
- [`init_model`](mmcls.apis.init_model): 根据配置文件和权重文件初始化一个模型。
|
||||
- [`inference_model`](mmcls.apis.inference_model):对给定图片进行推理。
|
||||
|
||||
下面是一个示例,如何使用一个 ImageNet-1k 预训练权重初始化模型并推理给定图像。
|
||||
|
||||
```{note}
|
||||
如果您将 mmcls 当作第三方库使用,需要下载样例中的配置文件以及样例图片。
|
||||
|
||||
运行 'mim download mmcls --config resnet50_8xb32_in1k --dest .' 下载所需配置文件。
|
||||
|
||||
运行 'wget https://github.com/open-mmlab/mmclassification/blob/master/demo/demo.JPEG' 下载所需图片。
|
||||
可以运行 `wget https://github.com/open-mmlab/mmclassification/raw/master/demo/demo.JPEG` 下载样例图片,或使用其他图片。
|
||||
```
|
||||
|
||||
```python
|
||||
from mmcls.apis import inference_model, init_model
|
||||
from mmcls.utils import register_all_modules
|
||||
from mmcls import get_model, inference_model
|
||||
|
||||
config_path = './configs/resnet/resnet50_8xb32_in1k.py'
|
||||
checkpoint_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # 也可以设置为一个本地的路径
|
||||
img_path = 'demo/demo.JPEG' # 可以指定自己的图片路径
|
||||
img_path = 'demo.JPEG' # 可以指定自己的图片路径
|
||||
|
||||
# 注册
|
||||
register_all_modules() # 将所有模块注册在默认 mmcls 域中
|
||||
# 构建模型
|
||||
model = init_model(config_path, checkpoint_path, device="cpu") # `device` 可以为 'cuda:0'
|
||||
model = get_model('resnet50_8xb32_in1k', pretrained=True, device="cpu") # `device` 可以为 'cuda:0'
|
||||
# 执行推理
|
||||
result = inference_model(model, img_path)
|
||||
print(result)
|
||||
```
|
||||
|
||||
`result` 为一个包含了 `pred_label`, `pred_score`, `pred_scores` 和 `pred_class`的字典,结果如下:
|
||||
|
|
|
@ -22,8 +22,6 @@ def inference_model(model: 'BaseModel', img: Union[str, np.ndarray]):
|
|||
from mmengine.dataset import Compose, default_collate
|
||||
from mmengine.registry import DefaultScope
|
||||
|
||||
import mmcls.datasets # noqa: F401
|
||||
|
||||
cfg = model.cfg
|
||||
# build the data pipeline
|
||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
|
|
|
@ -111,7 +111,6 @@ def init_model(config, checkpoint=None, device=None, **kwargs):
|
|||
config.model.setdefault('data_preprocessor',
|
||||
config.get('data_preprocessor', None))
|
||||
|
||||
import mmcls.models # noqa: F401
|
||||
from mmcls.registry import MODELS
|
||||
|
||||
model = MODELS.build(config.model)
|
||||
|
|
|
@ -7,7 +7,7 @@ import mmengine
|
|||
import numpy as np
|
||||
from mmengine.dataset import BaseDataset as _BaseDataset
|
||||
|
||||
from .builder import DATASETS
|
||||
from mmcls.registry import DATASETS, TRANSFORMS
|
||||
|
||||
|
||||
def expanduser(path):
|
||||
|
@ -89,6 +89,13 @@ class BaseDataset(_BaseDataset):
|
|||
ann_file = expanduser(ann_file)
|
||||
metainfo = self._compat_classes(metainfo, classes)
|
||||
|
||||
transforms = []
|
||||
for transform in pipeline:
|
||||
if isinstance(transform, dict):
|
||||
transforms.append(TRANSFORMS.build(transform))
|
||||
else:
|
||||
transforms.append(transform)
|
||||
|
||||
super().__init__(
|
||||
ann_file=ann_file,
|
||||
metainfo=metainfo,
|
||||
|
@ -97,7 +104,7 @@ class BaseDataset(_BaseDataset):
|
|||
filter_cfg=filter_cfg,
|
||||
indices=indices,
|
||||
serialize_data=serialize_data,
|
||||
pipeline=pipeline,
|
||||
pipeline=transforms,
|
||||
test_mode=test_mode,
|
||||
lazy_init=lazy_init,
|
||||
max_refetch=max_refetch)
|
||||
|
|
|
@ -65,8 +65,9 @@ class AutoAugment(RandomChoice):
|
|||
self.hparams = hparams
|
||||
self.policies = [[merge_hparams(t, hparams) for t in sub]
|
||||
for sub in policies]
|
||||
transforms = [[TRANSFORMS.build(t) for t in sub] for sub in policies]
|
||||
|
||||
super().__init__(transforms=self.policies)
|
||||
super().__init__(transforms=transforms)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
policies_str = ''
|
||||
|
|
|
@ -67,8 +67,7 @@ class TIMMBackbone(BaseBackbone):
|
|||
import timm
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Failed to import timm. Please run "pip install timm". '
|
||||
'"pip install dataclasses" may also be needed for Python 3.6.')
|
||||
'Failed to import timm. Please run "pip install timm".')
|
||||
|
||||
if not isinstance(pretrained, bool):
|
||||
raise TypeError('pretrained must be bool, not str for model path')
|
||||
|
|
|
@ -32,66 +32,159 @@ from mmengine.registry import \
|
|||
from mmengine.registry import Registry
|
||||
|
||||
__all__ = [
|
||||
'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', 'HOOKS', 'DATASETS',
|
||||
'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'MODEL_WRAPPERS',
|
||||
'WEIGHT_INITIALIZERS', 'BATCH_AUGMENTS', 'OPTIMIZERS', 'OPTIM_WRAPPERS',
|
||||
'OPTIM_WRAPPER_CONSTRUCTORS', 'PARAM_SCHEDULERS', 'METRICS', 'TASK_UTILS',
|
||||
'VISUALIZERS', 'VISBACKENDS', 'EVALUATORS', 'LOG_PROCESSORS'
|
||||
'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', 'HOOKS', 'LOG_PROCESSORS',
|
||||
'OPTIMIZERS', 'OPTIM_WRAPPERS', 'OPTIM_WRAPPER_CONSTRUCTORS',
|
||||
'PARAM_SCHEDULERS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS',
|
||||
'MODEL_WRAPPERS', 'WEIGHT_INITIALIZERS', 'BATCH_AUGMENTS', 'TASK_UTILS',
|
||||
'METRICS', 'EVALUATORS', 'VISUALIZERS', 'VISBACKENDS'
|
||||
]
|
||||
|
||||
# Registries For Runner and the related
|
||||
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||
RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
|
||||
# manage runner constructors that define how to initialize runners
|
||||
#######################################################################
|
||||
# mmcls.engine #
|
||||
#######################################################################
|
||||
|
||||
# Runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||
RUNNERS = Registry(
|
||||
'runner',
|
||||
parent=MMENGINE_RUNNERS,
|
||||
locations=['mmcls.engine'],
|
||||
)
|
||||
# Runner constructors that define how to initialize runners
|
||||
RUNNER_CONSTRUCTORS = Registry(
|
||||
'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
|
||||
# manage all kinds of loops like `EpochBasedTrainLoop`
|
||||
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
|
||||
# manage all kinds of hooks like `CheckpointHook`
|
||||
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
|
||||
|
||||
# Registries For Data and the related
|
||||
# manage data-related modules
|
||||
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
|
||||
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
|
||||
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
|
||||
|
||||
# manage all kinds of modules inheriting `nn.Module`
|
||||
MODELS = Registry('model', parent=MMENGINE_MODELS)
|
||||
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
|
||||
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
|
||||
# manage all kinds of weight initialization modules like `Uniform`
|
||||
WEIGHT_INITIALIZERS = Registry(
|
||||
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
|
||||
# manage all kinds of batch augmentations like Mixup and CutMix.
|
||||
BATCH_AUGMENTS = Registry('batch augment')
|
||||
|
||||
# Registries For Optimizer and the related
|
||||
# manage all kinds of optimizers like `SGD` and `Adam`
|
||||
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
|
||||
# manage optimizer wrapper
|
||||
OPTIM_WRAPPERS = Registry('optimizer_wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
|
||||
# manage constructors that customize the optimization hyperparameters.
|
||||
'runner constructor',
|
||||
parent=MMENGINE_RUNNER_CONSTRUCTORS,
|
||||
locations=['mmcls.engine'],
|
||||
)
|
||||
# Loops which define the training or test process, like `EpochBasedTrainLoop`
|
||||
LOOPS = Registry(
|
||||
'loop',
|
||||
parent=MMENGINE_LOOPS,
|
||||
locations=['mmcls.engine'],
|
||||
)
|
||||
# Hooks to add additional functions during running, like `CheckpointHook`
|
||||
HOOKS = Registry(
|
||||
'hook',
|
||||
parent=MMENGINE_HOOKS,
|
||||
locations=['mmcls.engine'],
|
||||
)
|
||||
# Log processors to process the scalar log data.
|
||||
LOG_PROCESSORS = Registry(
|
||||
'log processor',
|
||||
parent=MMENGINE_LOG_PROCESSORS,
|
||||
locations=['mmcls.engine'],
|
||||
)
|
||||
# Optimizers to optimize the model weights, like `SGD` and `Adam`.
|
||||
OPTIMIZERS = Registry(
|
||||
'optimizer',
|
||||
parent=MMENGINE_OPTIMIZERS,
|
||||
locations=['mmcls.engine'],
|
||||
)
|
||||
# Optimizer wrappers to enhance the optimization process.
|
||||
OPTIM_WRAPPERS = Registry(
|
||||
'optimizer_wrapper',
|
||||
parent=MMENGINE_OPTIM_WRAPPERS,
|
||||
locations=['mmcls.engine'],
|
||||
)
|
||||
# Optimizer constructors to customize the hyperparameters of optimizers.
|
||||
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
|
||||
'optimizer wrapper constructor',
|
||||
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
|
||||
# manage all kinds of parameter schedulers like `MultiStepLR`
|
||||
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
|
||||
locations=['mmcls.engine'],
|
||||
)
|
||||
# Parameter schedulers to dynamically adjust optimization parameters.
|
||||
PARAM_SCHEDULERS = Registry(
|
||||
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
|
||||
'parameter scheduler',
|
||||
parent=MMENGINE_PARAM_SCHEDULERS,
|
||||
locations=['mmcls.engine'],
|
||||
)
|
||||
|
||||
# manage all kinds of metrics
|
||||
METRICS = Registry('metric', parent=MMENGINE_METRICS)
|
||||
# manage all kinds of evaluators
|
||||
EVALUATORS = Registry('evaluator', parent=MMENGINE_EVALUATOR)
|
||||
#######################################################################
|
||||
# mmcls.datasets #
|
||||
#######################################################################
|
||||
|
||||
# manage task-specific modules like anchor generators and box coders
|
||||
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
|
||||
# Datasets like `ImageNet` and `CIFAR10`.
|
||||
DATASETS = Registry(
|
||||
'dataset',
|
||||
parent=MMENGINE_DATASETS,
|
||||
locations=['mmcls.datasets'],
|
||||
)
|
||||
# Samplers to sample the dataset.
|
||||
DATA_SAMPLERS = Registry(
|
||||
'data sampler',
|
||||
parent=MMENGINE_DATA_SAMPLERS,
|
||||
locations=['mmcls.datasets'],
|
||||
)
|
||||
# Transforms to process the samples from the dataset.
|
||||
TRANSFORMS = Registry(
|
||||
'transform',
|
||||
parent=MMENGINE_TRANSFORMS,
|
||||
locations=['mmcls.datasets'],
|
||||
)
|
||||
|
||||
# Registries For Visualizer and the related
|
||||
# manage visualizer
|
||||
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
|
||||
# manage visualizer backend
|
||||
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
|
||||
#######################################################################
|
||||
# mmcls.models #
|
||||
#######################################################################
|
||||
|
||||
# manage all kinds log processors
|
||||
LOG_PROCESSORS = Registry('log processor', parent=MMENGINE_LOG_PROCESSORS)
|
||||
# Neural network modules inheriting `nn.Module`.
|
||||
MODELS = Registry(
|
||||
'model',
|
||||
parent=MMENGINE_MODELS,
|
||||
locations=['mmcls.models'],
|
||||
)
|
||||
# Model wrappers like 'MMDistributedDataParallel'
|
||||
MODEL_WRAPPERS = Registry(
|
||||
'model_wrapper',
|
||||
parent=MMENGINE_MODEL_WRAPPERS,
|
||||
locations=['mmcls.models'],
|
||||
)
|
||||
# Weight initialization methods like uniform, xavier.
|
||||
WEIGHT_INITIALIZERS = Registry(
|
||||
'weight initializer',
|
||||
parent=MMENGINE_WEIGHT_INITIALIZERS,
|
||||
locations=['mmcls.models'],
|
||||
)
|
||||
# Batch augmentations like `Mixup` and `CutMix`.
|
||||
BATCH_AUGMENTS = Registry(
|
||||
'batch augment',
|
||||
locations=['mmcls.models'],
|
||||
)
|
||||
# Task-specific modules like anchor generators and box coders
|
||||
TASK_UTILS = Registry(
|
||||
'task util',
|
||||
parent=MMENGINE_TASK_UTILS,
|
||||
locations=['mmcls.models'],
|
||||
)
|
||||
|
||||
#######################################################################
|
||||
# mmcls.evaluation #
|
||||
#######################################################################
|
||||
|
||||
# Metrics to evaluate the model prediction results.
|
||||
METRICS = Registry(
|
||||
'metric',
|
||||
parent=MMENGINE_METRICS,
|
||||
locations=['mmcls.evaluation'],
|
||||
)
|
||||
# Evaluators to define the evaluation process.
|
||||
EVALUATORS = Registry(
|
||||
'evaluator',
|
||||
parent=MMENGINE_EVALUATOR,
|
||||
locations=['mmcls.evaluation'],
|
||||
)
|
||||
|
||||
#######################################################################
|
||||
# mmcls.visualization #
|
||||
#######################################################################
|
||||
|
||||
# Visualizers to display task-specific results.
|
||||
VISUALIZERS = Registry(
|
||||
'visualizer',
|
||||
parent=MMENGINE_VISUALIZERS,
|
||||
locations=['mmcls.visualization'],
|
||||
)
|
||||
# Backends to save the visualization results, like TensorBoard, WandB.
|
||||
VISBACKENDS = Registry(
|
||||
'vis_backend',
|
||||
parent=MMENGINE_VISBACKENDS,
|
||||
locations=['mmcls.visualization'],
|
||||
)
|
||||
|
|
2
setup.py
2
setup.py
|
@ -169,12 +169,12 @@ if __name__ == '__main__':
|
|||
keywords='computer vision, image classification',
|
||||
packages=find_packages(exclude=('configs', 'tools', 'demo')),
|
||||
include_package_data=True,
|
||||
python_requires='>=3.7',
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Operating System :: OS Independent',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
|
|
|
@ -11,9 +11,7 @@ import numpy as np
|
|||
from mmengine.logging import MMLogger
|
||||
|
||||
from mmcls.registry import DATASETS, TRANSFORMS
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
ASSETS_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '../data/dataset'))
|
||||
|
||||
|
||||
|
|
|
@ -7,9 +7,6 @@ from unittest.mock import ANY, patch
|
|||
import numpy as np
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
def construct_toy_data():
|
||||
|
|
|
@ -11,9 +11,6 @@ from PIL import Image
|
|||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from mmcls.structures import ClsDataSample, MultiTaskDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestPackClsInputs(unittest.TestCase):
|
||||
|
|
|
@ -9,9 +9,6 @@ import mmengine
|
|||
import numpy as np
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
def construct_toy_data():
|
||||
|
|
|
@ -15,9 +15,6 @@ from torch.utils.data import DataLoader, Dataset
|
|||
|
||||
from mmcls.registry import HOOKS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
|
|
@ -20,9 +20,6 @@ from mmcls.models import CrossEntropyLoss
|
|||
from mmcls.models.heads.cls_head import ClsHead
|
||||
from mmcls.models.losses import LabelSmoothLoss
|
||||
from mmcls.models.utils.batch_augments import RandomBatchAugment
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class SimpleDataPreprocessor(BaseDataPreprocessor):
|
||||
|
|
|
@ -10,11 +10,8 @@ from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop
|
|||
from mmcls.engine import VisualizationHook
|
||||
from mmcls.registry import HOOKS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
from mmcls.visualization import ClsVisualizer
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestVisualizationHook(TestCase):
|
||||
|
||||
|
|
|
@ -8,9 +8,6 @@ from mmengine.evaluator import Evaluator
|
|||
|
||||
from mmcls.evaluation.metrics import AveragePrecision, MultiLabelMetric
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestMultiLabel(TestCase):
|
||||
|
|
|
@ -7,9 +7,6 @@ import torch
|
|||
from mmengine.evaluator import Evaluator
|
||||
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestVOCMultiLabel(TestCase):
|
||||
|
|
|
@ -10,9 +10,6 @@ from mmengine import ConfigDict
|
|||
from mmcls.models import ImageClassifier
|
||||
from mmcls.registry import MODELS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
def has_timm() -> bool:
|
||||
|
|
|
@ -11,9 +11,6 @@ from mmengine import is_seq_of
|
|||
|
||||
from mmcls.registry import MODELS
|
||||
from mmcls.structures import ClsDataSample, MultiTaskDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
|
|
|
@ -14,9 +14,6 @@ from torch.utils.data import DataLoader, Dataset
|
|||
from mmcls.datasets.transforms import PackClsInputs
|
||||
from mmcls.registry import MODELS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
|
|
@ -8,9 +8,6 @@ from mmengine import ConfigDict
|
|||
from mmcls.models import AverageClsScoreTTA, ImageClassifier
|
||||
from mmcls.registry import MODELS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestAverageClsScoreTTA(TestCase):
|
||||
|
|
|
@ -6,9 +6,6 @@ import torch
|
|||
from mmcls.models import ClsDataPreprocessor, RandomBatchAugment
|
||||
from mmcls.registry import MODELS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestClsDataPreprocessor(TestCase):
|
||||
|
|
|
@ -1,15 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import mmengine
|
||||
import rich
|
||||
from mmengine import Config, DictAction
|
||||
from mmengine.evaluator import Evaluator
|
||||
|
||||
from mmcls.utils import register_all_modules
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -34,20 +30,18 @@ def parse_args():
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
register_all_modules()
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
init_default_scope('mmcls') # Use mmcls as default scope.
|
||||
|
||||
predictions = mmengine.load(args.pkl_results)
|
||||
|
||||
evaluator = Evaluator(cfg.test_evaluator)
|
||||
# dataset is not needed, use an endless iterator to mock it.
|
||||
fake_dataset = itertools.repeat({'data_sample': MagicMock()})
|
||||
eval_results = evaluator.offline_evaluate(fake_dataset, predictions)
|
||||
rich.print_json(json.dumps(eval_results))
|
||||
eval_results = evaluator.offline_evaluate(predictions, None)
|
||||
rich.print(eval_results)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -12,8 +12,6 @@ from mmengine.runner import Runner, find_latest_checkpoint
|
|||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
EXP_INFO_FILE = 'kfold_exp.json'
|
||||
|
||||
prog_description = """K-Fold cross-validation.
|
||||
|
@ -222,10 +220,6 @@ def train_single_fold(cfg, num_splits, fold, resume_ckpt=None):
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# register all modules in mmcls into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
|
||||
|
|
|
@ -9,8 +9,6 @@ from mmengine.config import Config, ConfigDict, DictAction
|
|||
from mmengine.hooks import Hook
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -158,12 +156,10 @@ def merge_args(cfg, args):
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# register all modules in mmcls into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
|
||||
# merge cli arguments to config
|
||||
cfg = merge_args(cfg, args)
|
||||
|
||||
# build the runner from config
|
||||
|
|
|
@ -9,8 +9,6 @@ from mmengine.runner import Runner
|
|||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Train a classifier')
|
||||
|
@ -141,10 +139,6 @@ def merge_args(cfg, args):
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# register all modules in mmcls into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
|
||||
|
|
|
@ -8,11 +8,11 @@ import mmcv
|
|||
import numpy as np
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.utils import ProgressBar
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
from mmcls.datasets.builder import build_dataset
|
||||
from mmcls.utils import register_all_modules
|
||||
from mmcls.visualization import ClsVisualizer
|
||||
from mmcls.visualization.cls_visualizer import _get_adaptive_scale
|
||||
|
||||
|
@ -170,8 +170,7 @@ def main():
|
|||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# register all modules in mmcls into the registries
|
||||
register_all_modules()
|
||||
init_default_scope('mmcls') # Use mmcls as default scope.
|
||||
|
||||
dataset_cfg = cfg.get(args.phase + '_dataloader').get('dataset')
|
||||
dataset = build_dataset(dataset_cfg)
|
||||
|
|
|
@ -11,12 +11,12 @@ import numpy as np
|
|||
from mmcv.transforms import Compose
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.dataset import default_collate
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.utils import to_2tuple
|
||||
from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm
|
||||
|
||||
from mmcls import digit_version
|
||||
from mmcls.apis import init_model
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
try:
|
||||
from pytorch_grad_cam import (EigenCAM, EigenGradCAM, GradCAM,
|
||||
|
@ -265,7 +265,7 @@ def main():
|
|||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
register_all_modules()
|
||||
init_default_scope('mmcls')
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_model(cfg, args.checkpoint, device=args.device)
|
||||
if args.preview_model:
|
||||
|
|
|
@ -16,8 +16,6 @@ from mmengine.runner import Runner
|
|||
from mmengine.visualization import Visualizer
|
||||
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn
|
||||
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
|
||||
class SimpleModel(BaseModel):
|
||||
"""simple model that do nothing in train_step."""
|
||||
|
@ -214,8 +212,6 @@ def main():
|
|||
osp.splitext(osp.basename(args.config))[0])
|
||||
|
||||
cfg.log_level = args.log_level
|
||||
# register all modules in mmcls into the registries
|
||||
register_all_modules()
|
||||
|
||||
# make sure save_root exists
|
||||
if args.save_path and not args.save_path.parent.exists():
|
||||
|
|
Loading…
Reference in New Issue