89 lines
3.2 KiB
Python
89 lines
3.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
from unittest import TestCase
|
|
from unittest.mock import patch
|
|
|
|
from mmengine import Config
|
|
|
|
from mmpretrain.apis import ModelHub, get_model, init_model, list_models
|
|
from mmpretrain.models import ImageClassifier, MobileNetV2
|
|
|
|
|
|
class TestModelHub(TestCase):
|
|
|
|
def test_mmpretrain_models(self):
|
|
self.assertIn('resnet18_8xb32_in1k', ModelHub._models_dict)
|
|
|
|
def test_register_model_index(self):
|
|
model_index_path = osp.join(osp.dirname(__file__), '../data/meta.yml')
|
|
|
|
ModelHub.register_model_index(model_index_path)
|
|
self.assertIn('test_model', ModelHub._models_dict)
|
|
self.assertEqual(
|
|
ModelHub._models_dict['test_model'].config,
|
|
osp.abspath(
|
|
osp.join(osp.dirname(model_index_path), 'test_config.py')))
|
|
|
|
with self.assertRaisesRegex(ValueError, 'meta.yml'):
|
|
# test name conflict
|
|
ModelHub.register_model_index(model_index_path)
|
|
|
|
# test specify config prefix
|
|
del ModelHub._models_dict['test_model']
|
|
ModelHub.register_model_index(
|
|
model_index_path, config_prefix='configs')
|
|
self.assertEqual(ModelHub._models_dict['test_model'].config,
|
|
osp.abspath(osp.join('configs', 'test_config.py')))
|
|
|
|
def test_get_model(self):
|
|
metainfo = ModelHub.get('resnet18_8xb32_in1k')
|
|
self.assertIsInstance(metainfo.weights, str)
|
|
self.assertIsInstance(metainfo.config, Config)
|
|
|
|
|
|
class TestHubAPIs(TestCase):
|
|
|
|
def test_list_models(self):
|
|
models_names = list_models()
|
|
self.assertIsInstance(models_names, list)
|
|
|
|
models_names = list_models(pattern='swin*in1k')
|
|
for model_name in models_names:
|
|
self.assertTrue(
|
|
model_name.startswith('swin') and 'in1k' in model_name)
|
|
|
|
def test_get_model(self):
|
|
model = get_model('mobilenet-v2_8xb32_in1k')
|
|
self.assertIsInstance(model, ImageClassifier)
|
|
self.assertIsInstance(model.backbone, MobileNetV2)
|
|
|
|
with patch('mmengine.runner.load_checkpoint') as mock:
|
|
model = get_model('mobilenet-v2_8xb32_in1k', pretrained=True)
|
|
model = get_model('mobilenet-v2_8xb32_in1k', pretrained='test.pth')
|
|
|
|
weight = mock.call_args_list[0][0][1]
|
|
self.assertIn('https', weight)
|
|
weight = mock.call_args_list[1][0][1]
|
|
self.assertEqual('test.pth', weight)
|
|
|
|
with self.assertRaisesRegex(ValueError, 'Failed to find'):
|
|
get_model('unknown-model')
|
|
|
|
def test_init_model(self):
|
|
# test init from config object
|
|
cfg = ModelHub.get('mobilenet-v2_8xb32_in1k').config
|
|
model = init_model(cfg)
|
|
self.assertIsInstance(model, ImageClassifier)
|
|
self.assertIsInstance(model.backbone, MobileNetV2)
|
|
|
|
# test init from config file
|
|
cfg = ModelHub._models_dict['mobilenet-v2_8xb32_in1k'].config
|
|
self.assertIsInstance(cfg, str)
|
|
model = init_model(cfg)
|
|
self.assertIsInstance(model, ImageClassifier)
|
|
self.assertIsInstance(model.backbone, MobileNetV2)
|
|
|
|
# test modify configs of the model
|
|
model = init_model(cfg, head=dict(num_classes=10))
|
|
self.assertEqual(model.head.num_classes, 10)
|