[Feature] Support getting model from the name defined in the model-index file. (#1236)
* [Feature] Support getting a model from the name defined in the model-index file.
* Add unit tests.
* Prevent importing `timm` when `TIMMBackbone` is not used.
* Fix Windows CI.
* Move `init_model` to `mmcls.apis.hub`, and support passing nn.Module to all model components.
* Fix requirements.
* Rename `hub.py` to `model.py` and add unit tests.
parent d990982fc0
commit c127c474b9
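For reference, a minimal usage sketch of the APIs this commit introduces; the model names, the demo image path, and the `pred_class` key are taken from the docstrings and tests further down, the rest is illustrative:

>>> from mmcls import list_models, get_model, inference_model
>>> list_models('resnet*in1k')                      # wildcard search over the model-index
['resnet50_8xb32_in1k', 'resnet50_8xb32-fp16_in1k', ...]
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> result['pred_class']                            # predicted class name as a string
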
@@ -13,12 +13,17 @@ These are some high-level APIs for classification tasks.
    :local:
    :backlinks: top
 
+Model
+------------------
+
+.. autofunction:: list_models
+
+.. autofunction:: get_model
+
+.. autofunction:: init_model
+
+
 Inference
 ------------------
 
-.. autosummary::
-   :toctree: generated
-   :nosignatures:
-
-   init_model
-   inference_model
+.. autofunction:: inference_model

@@ -3,6 +3,7 @@ import mmcv
 import mmengine
 from mmengine.utils import digit_version
 
+from .apis import *  # noqa: F401, F403
 from .version import __version__
 
 mmcv_minimum_version = '2.0.0rc1'

@@ -1,4 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .inference import inference_model, init_model
+from .inference import inference_model
+from .model import ModelHub, get_model, init_model, list_models
 
-__all__ = ['init_model', 'inference_model']
+__all__ = [
+    'init_model', 'inference_model', 'list_models', 'get_model', 'ModelHub'
+]

@@ -1,62 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import warnings
+from typing import Union
 
+import numpy as np
 import torch
-from mmengine.config import Config
-from mmengine.dataset import Compose, pseudo_collate
-from mmengine.runner import load_checkpoint
+from mmengine.dataset import Compose, default_collate
+from mmengine.model import BaseModel
+from mmengine.registry import DefaultScope
 
-from mmcls.models import build_classifier
-from mmcls.utils import register_all_modules
+import mmcls.datasets  # noqa: F401
 
 
-def init_model(config, checkpoint=None, device='cuda:0', options=None):
-    """Initialize a classifier from config file.
-
-    Args:
-        config (str or :obj:`mmengine.Config`): Config file path or the config
-            object.
-        checkpoint (str, optional): Checkpoint path. If left as None, the model
-            will not load any weights.
-        options (dict): Options to override some settings in the used config.
-
-    Returns:
-        nn.Module: The constructed classifier.
-    """
-    register_all_modules()
-    if isinstance(config, str):
-        config = Config.fromfile(config)
-    elif not isinstance(config, Config):
-        raise TypeError('config must be a filename or Config object, '
-                        f'but got {type(config)}')
-    if options is not None:
-        config.merge_from_dict(options)
-    config.model.setdefault('data_preprocessor',
-                            config.get('data_preprocessor', None))
-    model = build_classifier(config.model)
-    if checkpoint is not None:
-        # Mapping the weights to GPU may cause unexpected video memory leak
-        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
-        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
-        if 'dataset_meta' in checkpoint.get('meta', {}):
-            # mmcls 1.x
-            model.CLASSES = checkpoint['meta']['dataset_meta']['classes']
-        elif 'CLASSES' in checkpoint.get('meta', {}):
-            # mmcls < 1.x
-            model.CLASSES = checkpoint['meta']['CLASSES']
-        else:
-            from mmcls.datasets.categories import IMAGENET_CATEGORIES
-            warnings.simplefilter('once')
-            warnings.warn('Class names are not saved in the checkpoint\'s '
-                          'meta data, use imagenet by default.')
-            model.CLASSES = IMAGENET_CATEGORIES
-    model.cfg = config  # save the config in the model for convenience
-    model.to(device)
-    model.eval()
-    return model
-
-
-def inference_model(model, img):
+def inference_model(model: BaseModel, img: Union[str, np.ndarray]):
     """Inference image(s) with the classifier.
 
     Args:

@@ -67,7 +21,6 @@ def inference_model(model, img):
         result (dict): The classification results that contains
             `class_name`, `pred_label` and `pred_score`.
     """
-    register_all_modules()
     cfg = model.cfg
     # build the data pipeline
     test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline

@@ -79,9 +32,10 @@ def inference_model(model, img):
     if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
         test_pipeline_cfg.pop(0)
     data = dict(img=img)
-    test_pipeline = Compose(test_pipeline_cfg)
+    with DefaultScope.overwrite_default_scope('mmcls'):
+        test_pipeline = Compose(test_pipeline_cfg)
     data = test_pipeline(data)
-    data = pseudo_collate([data])
+    data = default_collate([data])
 
     # forward the model
     with torch.no_grad():

@@ -0,0 +1,220 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import fnmatch
+import os.path as osp
+import warnings
+from os import PathLike
+from pathlib import Path
+from typing import List, Union
+
+from mmengine.config import Config
+from mmengine.runner import load_checkpoint
+from mmengine.utils import get_installed_path
+from modelindex.load_model_index import load
+from modelindex.models.Model import Model
+
+import mmcls.models  # noqa: F401
+from mmcls.registry import MODELS
+
+
+class ModelHub:
+    """A hub to host the meta information of all pre-defined models."""
+    _models_dict = {}
+
+    @classmethod
+    def register_model_index(cls,
+                             model_index_path: Union[str, PathLike],
+                             config_prefix: Union[str, PathLike, None] = None):
+        """Parse the model-index file and register all models.
+
+        Args:
+            model_index_path (str | PathLike): The path of the model-index
+                file.
+            config_prefix (str | PathLike | None): The prefix of all config
+                file paths in the model-index file.
+        """
+        model_index = load(str(model_index_path))
+        model_index.build_models_with_collections()
+
+        for metainfo in model_index.models:
+            model_name = metainfo.name.lower()
+            if metainfo.name in cls._models_dict:
+                raise ValueError(
+                    'The model name {} is conflict in {} and {}.'.format(
+                        model_name, osp.abspath(metainfo.filepath),
+                        osp.abspath(cls._models_dict[model_name].filepath)))
+            metainfo.config = cls._expand_config_path(metainfo, config_prefix)
+            cls._models_dict[model_name] = metainfo
+
+    @classmethod
+    def get(cls, model_name):
+        """Get the model's metainfo by the model name.
+
+        Args:
+            model_name (str): The name of model.
+
+        Returns:
+            modelindex.models.Model: The metainfo of the specified model.
+        """
+        # lazy load config
+        metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
+        if metainfo is None:
+            raise ValueError(f'Failed to find model {model_name}.')
+        if isinstance(metainfo.config, str):
+            metainfo.config = Config.fromfile(metainfo.config)
+        return metainfo
+
+    @staticmethod
+    def _expand_config_path(metainfo: Model,
+                            config_prefix: Union[str, PathLike] = None):
+        if config_prefix is None:
+            config_prefix = osp.dirname(metainfo.filepath)
+
+        if metainfo.config is None or osp.isabs(metainfo.config):
+            config_path: str = metainfo.config
+        else:
+            config_path = osp.abspath(osp.join(config_prefix, metainfo.config))
+
+        return config_path
+
+
+# register models in mmcls
+mmcls_root = Path(get_installed_path('mmcls'))
+model_index_path = mmcls_root / '.mim' / 'model-index.yml'
+ModelHub.register_model_index(
+    model_index_path, config_prefix=mmcls_root / '.mim')
+
+
+def init_model(config, checkpoint=None, device=None, **kwargs):
+    """Initialize a classifier from config file.
+
+    Args:
+        config (str | :obj:`mmengine.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+        device (str | torch.device | None): Transfer the model to the target
+            device. Defaults to None.
+        **kwargs: Other keyword arguments of the model config.
+
+    Returns:
+        nn.Module: The constructed model.
+    """
+    if isinstance(config, (str, PathLike)):
+        config = Config.fromfile(config)
+    elif not isinstance(config, Config):
+        raise TypeError('config must be a filename or Config object, '
+                        f'but got {type(config)}')
+    if kwargs:
+        config.merge_from_dict({'model': kwargs})
+    config.model.setdefault('data_preprocessor',
+                            config.get('data_preprocessor', None))
+    model = MODELS.build(config.model)
+    if checkpoint is not None:
+        # Mapping the weights to GPU may cause unexpected video memory leak
+        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
+        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+        if 'dataset_meta' in checkpoint.get('meta', {}):
+            # mmcls 1.x
+            model.CLASSES = checkpoint['meta']['dataset_meta']['classes']
+        elif 'CLASSES' in checkpoint.get('meta', {}):
+            # mmcls < 1.x
+            model.CLASSES = checkpoint['meta']['CLASSES']
+        else:
+            from mmcls.datasets.categories import IMAGENET_CATEGORIES
+            warnings.simplefilter('once')
+            warnings.warn('Class names are not saved in the checkpoint\'s '
+                          'meta data, use imagenet by default.')
+            model.CLASSES = IMAGENET_CATEGORIES
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+def get_model(model_name, pretrained=False, device=None, **kwargs):
+    """Get a pre-defined model by the name of model.
+
+    Args:
+        model_name (str): The name of model.
+        pretrained (bool | str): If True, load the pre-defined pretrained
+            weights. If a string, load the weights from it. Defaults to False.
+        device (str | torch.device | None): Transfer the model to the target
+            device. Defaults to None.
+        **kwargs: Other keyword arguments of the model config.
+
+    Returns:
+        mmengine.model.BaseModel: The result model.
+
+    Examples:
+        Get a ResNet-50 model and extract images feature:
+
+        >>> import torch
+        >>> from mmcls import get_model
+        >>> inputs = torch.rand(16, 3, 224, 224)
+        >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
+        >>> feats = model.extract_feat(inputs)
+        >>> for feat in feats:
+        ...     print(feat.shape)
+        torch.Size([16, 256])
+        torch.Size([16, 512])
+        torch.Size([16, 1024])
+        torch.Size([16, 2048])
+
+        Get Swin-Transformer model with pre-trained weights and inference:
+
+        >>> from mmcls import get_model, inference_model
+        >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
+        >>> result = inference_model(model, 'demo/demo.JPEG')
+        >>> print(result['pred_class'])
+        'sea snake'
+    """  # noqa: E501
+    metainfo = ModelHub.get(model_name)
+
+    if isinstance(pretrained, str):
+        ckpt = pretrained
+    elif pretrained:
+        if metainfo.weights is None:
+            raise ValueError(
+                f"The model {model_name} doesn't have pretrained weights.")
+        ckpt = metainfo.weights
+    else:
+        ckpt = None
+
+    if metainfo.config is None:
+        raise ValueError(
+            f"The model {model_name} doesn't support building by now.")
+    model = init_model(metainfo.config, ckpt, device=device, **kwargs)
+    return model
+
+
+def list_models(pattern=None) -> List[str]:
+    """List all models available in MMClassification.
+
+    Args:
+        pattern (str | None): A wildcard pattern to match model names.
+
+    Returns:
+        List[str]: a list of model names.
+
+    Examples:
+        List all models:
+
+        >>> from mmcls import list_models
+        >>> print(list_models())
+
+        List ResNet-50 models on ImageNet-1k dataset:
+
+        >>> from mmcls import list_models
+        >>> print(list_models('resnet*in1k'))
+        ['resnet50_8xb32_in1k',
+         'resnet50_8xb32-fp16_in1k',
+         'resnet50_8xb256-rsb-a1-600e_in1k',
+         'resnet50_8xb256-rsb-a2-300e_in1k',
+         'resnet50_8xb256-rsb-a3-100e_in1k']
+    """
+    if pattern is None:
+        return sorted(list(ModelHub._models_dict.keys()))
+    # Always match keys with any postfix.
+    matches = fnmatch.filter(ModelHub._models_dict.keys(), pattern + '*')
+    return matches

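A short sketch of the `**kwargs` override path in `init_model`/`get_model` above, mirroring the unit test added at the end of this commit: keyword arguments are merged into `config.model` before the model is built.

>>> from mmcls.apis import ModelHub, init_model
>>> cfg = ModelHub.get('mobilenet-v2_8xb32_in1k').config
>>> model = init_model(cfg, head=dict(num_classes=10))  # merged into config.model.head
>>> model.head.num_classes
10
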
@@ -1,9 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-try:
-    import timm
-except ImportError:
-    timm = None
-
 import warnings
 
 from mmengine.logging import MMLogger

@@ -68,10 +63,13 @@ class TIMMBackbone(BaseBackbone):
                  in_channels=3,
                  init_cfg=None,
                  **kwargs):
-        if timm is None:
-            raise RuntimeError(
+        try:
+            import timm
+        except ImportError:
+            raise ImportError(
                 'Failed to import timm. Please run "pip install timm". '
                 '"pip install dataclasses" may also be needed for Python 3.6.')
+
         if not isinstance(pretrained, bool):
             raise TypeError('pretrained must be bool, not str for model path')
         if features_only and checkpoint_path:

@@ -4,6 +4,7 @@ from collections import OrderedDict
 from typing import List, Optional
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 from mmcls.registry import MODELS

@@ -96,7 +97,9 @@ class HuggingFaceClassifier(BaseClassifier):
             **kwargs)
         self.model = AutoModelForImageClassification.from_config(config)
 
-        self.loss_module = MODELS.build(loss)
+        if not isinstance(loss, nn.Module):
+            loss = MODELS.build(loss)
+        self.loss_module = loss
 
         self.with_cp = with_cp
         if self.with_cp:

@@ -2,6 +2,7 @@
 from typing import List, Optional
 
 import torch
+import torch.nn as nn
 
 from mmcls.registry import MODELS
 from mmcls.structures import ClsDataSample

@@ -61,13 +62,16 @@ class ImageClassifier(BaseClassifier):
         super(ImageClassifier, self).__init__(
             init_cfg=init_cfg, data_preprocessor=data_preprocessor)
 
-        self.backbone = MODELS.build(backbone)
+        if not isinstance(backbone, nn.Module):
+            backbone = MODELS.build(backbone)
+        if neck is not None and not isinstance(neck, nn.Module):
+            neck = MODELS.build(neck)
+        if head is not None and not isinstance(head, nn.Module):
+            head = MODELS.build(head)
 
-        if neck is not None:
-            self.neck = MODELS.build(neck)
-
-        if head is not None:
-            self.head = MODELS.build(head)
+        self.backbone = backbone
+        self.neck = neck
+        self.head = head
 
     def forward(self,
                 inputs: torch.Tensor,

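With this change, `ImageClassifier` (and the other components touched below) accepts already-built `nn.Module` instances as well as config dicts. A minimal sketch of the new behavior, reusing the `MobileNetV2` backbone that the unit tests check for; the `LinearClsHead` config and its `in_channels=1280` are assumptions based on standard mmcls components, not part of this diff:

>>> import torch.nn as nn
>>> from mmcls.models import ImageClassifier, MobileNetV2
>>> backbone = MobileNetV2(widen_factor=1.0)
>>> model = ImageClassifier(
...     backbone=backbone,
...     head=dict(type='LinearClsHead', num_classes=10, in_channels=1280))
>>> model.backbone is backbone  # the pre-built module is used as-is, not rebuilt
True
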
@@ -4,6 +4,7 @@ from collections import OrderedDict
 from typing import List, Optional
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 from mmcls.registry import MODELS

@@ -79,7 +80,10 @@ class TimmClassifier(BaseClassifier):
             init_cfg=init_cfg, data_preprocessor=data_preprocessor)
         from timm.models import create_model
         self.model = create_model(*args, **kwargs)
-        self.loss_module = MODELS.build(loss)
+
+        if not isinstance(loss, nn.Module):
+            loss = MODELS.build(loss)
+        self.loss_module = loss
 
         self.with_cp = with_cp
         if self.with_cp:

@@ -2,6 +2,7 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 from mmcls.evaluation.metrics import Accuracy

@@ -34,7 +35,9 @@ class ClsHead(BaseHead):
         super(ClsHead, self).__init__(init_cfg=init_cfg)
 
         self.topk = topk
-        self.loss_module = MODELS.build(loss)
+        if not isinstance(loss, nn.Module):
+            loss = MODELS.build(loss)
+        self.loss_module = loss
         self.cal_acc = cal_acc
 
     def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:

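The same pattern is applied to the other heads below: `loss` may now be either a config dict or a pre-built module. A minimal sketch, assuming `CrossEntropyLoss` from `mmcls.models` as the pre-built loss (that class is not shown in this diff):

>>> from mmcls.models import ClsHead, CrossEntropyLoss
>>> head = ClsHead(loss=CrossEntropyLoss(loss_weight=1.0))  # nn.Module, used as-is
>>> isinstance(head.loss_module, CrossEntropyLoss)
True
>>> head = ClsHead(loss=dict(type='CrossEntropyLoss'))      # config dict, built via MODELS
>>> isinstance(head.loss_module, CrossEntropyLoss)
True
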
@@ -157,7 +157,9 @@ class ArcFaceClsHead(ClsHead):
                  init_cfg: Optional[dict] = None):
 
         super(ArcFaceClsHead, self).__init__(init_cfg=init_cfg)
-        self.loss_module = MODELS.build(loss)
+        if not isinstance(loss, nn.Module):
+            loss = MODELS.build(loss)
+        self.loss_module = loss
 
         assert num_subcenters >= 1 and num_classes >= 0
         self.in_channels = in_channels

@@ -2,6 +2,7 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
+import torch.nn as nn
 from mmengine.structures import LabelData
 
 from mmcls.registry import MODELS

@@ -36,7 +37,9 @@ class MultiLabelClsHead(BaseHead):
                  init_cfg: Optional[dict] = None):
         super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg)
 
-        self.loss_module = MODELS.build(loss)
+        if not isinstance(loss, nn.Module):
+            loss = MODELS.build(loss)
+        self.loss_module = loss
 
         if thr is None and topk is None:
             thr = 0.5

@@ -3,6 +3,7 @@ from typing import Callable, List, Optional, Union
 
 import mmengine.dist as dist
 import torch
+import torch.nn as nn
 from mmengine.runner import Runner
 from torch.utils.data import DataLoader
 

@@ -75,8 +76,13 @@ class ImageToImageRetriever(BaseRetriever):
         super(ImageToImageRetriever, self).__init__(
             init_cfg=init_cfg, data_preprocessor=data_preprocessor)
 
-        self.image_encoder = MODELS.build(image_encoder)
-        self.head = MODELS.build(head) if head else None
+        if not isinstance(image_encoder, nn.Module):
+            image_encoder = MODELS.build(image_encoder)
+        if head is not None and not isinstance(head, nn.Module):
+            head = MODELS.build(head)
+
+        self.image_encoder = image_encoder
+        self.head = head
 
         self.similarity = similarity_fn
 

@@ -4,6 +4,6 @@ interrogate
 isort==4.3.21
 mmdet>=3.0.0rc0
 pytest
-sklearn
+scikit-learn
 xdoctest >= 0.10.0
 yapf

@@ -0,0 +1,13 @@
+Models:
+  - Name: test_model
+    Metadata:
+      FLOPs: 319000000
+      Parameters: 3500000
+    Results:
+      - Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 71.86
+          Top 5 Accuracy: 90.42
+        Task: Image Classification
+    Weights: test_weight.pth
+    Config: test_config.py

@@ -0,0 +1,91 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from unittest import TestCase
+from unittest.mock import patch
+
+from mmengine import Config
+
+from mmcls.apis import ModelHub, get_model, init_model, list_models
+from mmcls.models import ImageClassifier, MobileNetV2
+
+
+class TestModelHub(TestCase):
+
+    def test_mmcls_models(self):
+        self.assertIn('resnet18_8xb32_in1k', ModelHub._models_dict)
+
+    def test_register_model_index(self):
+        model_index_path = osp.join(osp.dirname(__file__), '../data/meta.yml')
+
+        ModelHub.register_model_index(model_index_path)
+        self.assertIn('test_model', ModelHub._models_dict)
+        self.assertEqual(
+            ModelHub._models_dict['test_model'].config,
+            osp.abspath(
+                osp.join(osp.dirname(model_index_path), 'test_config.py')))
+
+        with self.assertRaisesRegex(ValueError, 'meta.yml'):
+            # test name conflict
+            ModelHub.register_model_index(model_index_path)
+
+        # test specify config prefix
+        del ModelHub._models_dict['test_model']
+        ModelHub.register_model_index(
+            model_index_path, config_prefix='configs')
+        self.assertEqual(ModelHub._models_dict['test_model'].config,
+                         osp.abspath(osp.join('configs', 'test_config.py')))
+
+    def test_get_model(self):
+        metainfo = ModelHub.get('resnet18_8xb32_in1k')
+        self.assertIsInstance(metainfo.weights, str)
+        self.assertIsInstance(metainfo.config, Config)
+
+
+class TestHubAPIs(TestCase):
+
+    def test_list_models(self):
+        models_names = list_models()
+        self.assertIsInstance(models_names, list)
+
+        models_names = list_models(pattern='swin*in1k')
+        for model_name in models_names:
+            self.assertTrue(
+                model_name.startswith('swin') and 'in1k' in model_name)
+
+    def test_get_model(self):
+        model = get_model('mobilenet-v2_8xb32_in1k')
+        self.assertIsInstance(model, ImageClassifier)
+        self.assertIsInstance(model.backbone, MobileNetV2)
+
+        with patch('mmcls.apis.model.init_model') as mock:
+            model = get_model('mobilenet-v2_8xb32_in1k', pretrained=True)
+            model = get_model('mobilenet-v2_8xb32_in1k', pretrained='test.pth')
+
+            weight = mock.call_args_list[0][0][1]
+            self.assertIn('https', weight)
+            weight = mock.call_args_list[1][0][1]
+            self.assertEqual('test.pth', weight)
+
+        with self.assertRaisesRegex(ValueError, 'Failed to find'):
+            get_model('unknown-model')
+
+        with self.assertRaisesRegex(ValueError, "doesn't support building"):
+            get_model('swinv2-base-w12_3rdparty_in21k-192px')
+
+    def test_init_model(self):
+        # test init from config object
+        cfg = ModelHub.get('mobilenet-v2_8xb32_in1k').config
+        model = init_model(cfg)
+        self.assertIsInstance(model, ImageClassifier)
+        self.assertIsInstance(model.backbone, MobileNetV2)
+
+        # test init from config file
+        cfg = ModelHub._models_dict['mobilenet-v2_8xb32_in1k'].config
+        self.assertIsInstance(cfg, str)
+        model = init_model(cfg)
+        self.assertIsInstance(model, ImageClassifier)
+        self.assertIsInstance(model.backbone, MobileNetV2)
+
+        # test modify configs of the model
+        model = init_model(cfg, head=dict(num_classes=10))
+        self.assertEqual(model.head.num_classes, 10)
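
A small illustration of the matching rule exercised by `test_list_models` above: `list_models` appends a trailing `*` to the pattern and filters the registered names with `fnmatch`, so a pattern only has to match the beginning of a name:

>>> import fnmatch
>>> names = ['resnet50_8xb32_in1k', 'swin-base_16xb64_in1k']
>>> fnmatch.filter(names, 'swin*in1k' + '*')  # same filtering as list_models('swin*in1k')
['swin-base_16xb64_in1k']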