[Enhancement] Refine GlobalAccessble (#144)

* rename global accessible and intergration get_sintance and create_instance

* move ManagerMixin to utils

* fix as docstring and seporate get_instance to get_instance and get_current_instance

* fix lint

* fix docstring, rename and move test_global_meta

* fix manager's runtime error description

fix manager's runtime error description

* Add comments

* Add comments
pull/156/head
Mashiro 2022-03-26 21:21:25 +08:00 committed by GitHub
parent 2bf099d33c
commit 1048584147
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 284 additions and 330 deletions

View File

@ -36,3 +36,4 @@ Distributed
Logging
--------
.. automodule:: mmengine.logging
:members:

View File

@ -1,10 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_global_accsessible import BaseGlobalAccessible, MetaGlobalAccessible
from .log_buffer import LogBuffer
from .logger import MMLogger, print_log
from .message_hub import MessageHub
__all__ = [
'LogBuffer', 'MessageHub', 'MetaGlobalAccessible', 'BaseGlobalAccessible',
'MMLogger', 'print_log'
]
__all__ = ['LogBuffer', 'MessageHub', 'MMLogger', 'print_log']

View File

@ -1,173 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from collections import OrderedDict
from typing import Any, Optional
class MetaGlobalAccessible(type):
"""The metaclass for global accessible class.
The subclasses inheriting from ``MetaGlobalAccessible`` will manage their
own ``_instance_dict`` and root instances. The constructors of subclasses
must contain an optional ``name`` argument and all other arguments must
have default values.
Examples:
>>> class SubClass1(metaclass=MetaGlobalAccessible):
>>> def __init__(self, *args, **kwargs):
>>> pass
AssertionError: <class '__main__.SubClass1'>.__init__ must have the
name argument.
>>> class SubClass2(metaclass=MetaGlobalAccessible):
>>> def __init__(self, a, name=None, **kwargs):
>>> pass
AssertionError:
In <class '__main__.SubClass2'>.__init__, Only the name argument is
allowed to have no default values.
>>> class SubClass3(metaclass=MetaGlobalAccessible):
>>> def __init__(self, name, **kwargs):
>>> pass # Right format
>>> class SubClass4(metaclass=MetaGlobalAccessible):
>>> def __init__(self, a=1, name='', **kwargs):
>>> pass # Right format
"""
def __init__(cls, *args):
cls._instance_dict = OrderedDict()
params = inspect.getfullargspec(cls)
# `inspect.getfullargspec` returns a tuple includes `(args, varargs,
# varkw, defaults, kwonlyargs, kwonlydefaults, annotations)`.
# To make sure `cls(name='root')` can be implemented, the
# `args` and `defaults` should be checked.
params_names = params[0] if params[0] else []
default_params = params[3] if params[3] else []
assert 'name' in params_names, f'{cls}.__init__ must have the name ' \
'argument'
if len(default_params) == len(params_names) - 2 and 'name' != \
params[0][1]:
raise AssertionError(f'In {cls}.__init__, Only the name argument '
'is allowed to have no default values.')
if len(default_params) < len(params_names) - 2:
raise AssertionError('Besides name, the arguments of the '
f'{cls}.__init__ must have default values')
cls.root = cls(name='root')
super().__init__(*args)
class BaseGlobalAccessible(metaclass=MetaGlobalAccessible):
"""``BaseGlobalAccessible`` is the base class for classes that have global
access requirements.
The subclasses inheriting from ``BaseGlobalAccessible`` can get their
global instancees.
Examples:
>>> class GlobalAccessible(BaseGlobalAccessible):
>>> def __init__(self, name=''):
>>> super().__init__(name)
>>>
>>> GlobalAccessible.create_instance('name')
>>> instance_1 = GlobalAccessible.get_instance('name')
>>> instance_2 = GlobalAccessible.get_instance('name')
>>> assert id(instance_1) == id(instance_2)
Args:
name (str): Name of the instance. Defaults to ''.
"""
def __init__(self, name: str = '', **kwargs):
self._name = name
@classmethod
def create_instance(cls, name: str = '', **kwargs) -> Any:
"""Create subclass instance by name, and subclass cannot create
instances with duplicated names. The created instance will be stored in
``cls._instance_dict``, and can be accessed by ``get_instance``.
Examples:
>>> instance_1 = GlobalAccessible.create_instance('name')
>>> instance_2 = GlobalAccessible.create_instance('name')
AssertionError: <class '__main__.GlobalAccessible'> cannot be
created by name twice.
>>> root_instance = GlobalAccessible.create_instance()
>>> root_instance.instance_name # get default root instance
root
Args:
name (str): Name of instance. Defaults to ''.
Returns:
object: Subclass instance.
"""
instance_dict = cls._instance_dict
# Create instance and fill the instance in the `instance_dict`.
if name:
assert name not in instance_dict, f'{cls} cannot be created by ' \
f'{name} twice.'
instance = cls(name=name, **kwargs)
instance_dict[name] = instance
return instance
# Get default root instance.
else:
if kwargs:
raise ValueError('If name is not specified, create_instance '
f'will return root {cls} and cannot accept '
f'any arguments, but got kwargs: {kwargs}')
return cls.root
@classmethod
def get_instance(cls, name: str = '', current: bool = False) -> Any:
"""Get subclass instance by name if the name exists. if name is not
specified, this method will return latest created instance of root
instance.
Examples
>>> instance = GlobalAccessible.create_instance('name1')
>>> instance = GlobalAccessible.get_instance('name1')
>>> instance.instance_name
name1
>>> instance = GlobalAccessible.create_instance('name2')
>>> instance = GlobalAccessible.get_instance(current=True)
>>> instance.instance_name
name2
>>> instance = GlobalAccessible.get_instance()
>>> instance.instance_name # get root instance
root
>>> instance = GlobalAccessible.get_instance('name3') # error
AssertionError: Cannot get <class '__main__.GlobalAccessible'> by
name: name3, please make sure you have created it
Args:
name (str): Name of instance. Defaults to ''.
current(bool): Whether to return the latest created instance or
the root instance, if name is not spicified. Defaults to False.
Returns:
object: Corresponding name instance, the latest instance, or root
instance.
"""
instance_dict = cls._instance_dict
# Get the instance by name.
if name:
assert name in instance_dict, \
f'Cannot get {cls} by name: {name}, please make sure you ' \
'have created it'
return instance_dict[name]
# Get latest instantiated instance or root instance.
else:
if current:
current_name = next(iter(reversed(cls._instance_dict)))
assert current_name, f'Before calling {cls}.get_instance, ' \
'you should call create_instance.'
return cls._instance_dict[current_name]
else:
return cls.root
@property
def instance_name(self) -> Optional[str]:
"""Get the name of instance.
Returns:
str: Name of instance.
"""
return self._name

View File

@ -8,7 +8,7 @@ from typing import Optional, Union
import torch.distributed as dist
from termcolor import colored
from .base_global_accsessible import BaseGlobalAccessible
from mmengine.utils import ManagerMixin
class MMFormatter(logging.Formatter):
@ -84,10 +84,10 @@ class MMFormatter(logging.Formatter):
return result
class MMLogger(Logger, BaseGlobalAccessible):
class MMLogger(Logger, ManagerMixin):
"""The Logger manager which can create formatted logger and get specified
logger globally. MMLogger is created and accessed in the same way as
BaseGlobalAccessible.
ManagerMixin.
Args:
name (str): Logger name. Defaults to ''.
@ -104,7 +104,7 @@ class MMLogger(Logger, BaseGlobalAccessible):
log_level: str = 'NOTSET',
file_mode: str = 'w'):
Logger.__init__(self, name)
BaseGlobalAccessible.__init__(self, name)
ManagerMixin.__init__(self, name)
# Get rank in DDP mode.
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
@ -137,19 +137,22 @@ class MMLogger(Logger, BaseGlobalAccessible):
def print_log(msg,
logger: Optional[Union[Logger, str]] = None,
level=logging.INFO):
level=logging.INFO) -> None:
"""Print a log message.
Args:
msg (str): The message to be logged.
logger (Logger or str, optional): The logger to be used.
logger (Logger or str, optional): If the type of logger is
``logging.Logger``, we directly use logger to log messages.
Some special loggers are:
- "silent": no message will be printed.
- "current": Log message via the latest created logger.
- other str: the logger obtained with `MMLogger.get_instance`.
- "silent": No message will be printed.
- "current": Use latest created logger to log message.
- other str: Instance name of logger. The corresponding logger
will log message if it has been created, otherwise ``print_log``
will raise a `ValueError`.
- None: The `print()` method will be used to print log messages.
level (int): Logging level. Only available when `logger` is a Logger
object or "root".
object, "current", or a created logger instance name.
"""
if logger is None:
print(msg)
@ -158,13 +161,17 @@ def print_log(msg,
elif logger == 'silent':
pass
elif logger == 'current':
logger_instance = MMLogger.get_instance(current=True)
logger_instance = MMLogger.get_current_instance()
logger_instance.log(level, msg)
elif isinstance(logger, str):
try:
_logger = MMLogger.get_instance(logger)
_logger.log(level, msg)
except AssertionError:
# If the type of `logger` is `str`, but not with value of `current` or
# `silent`, we assume it indicates the name of the logger. If the
# corresponding logger has not been created, `print_log` will raise
# a `ValueError`.
if MMLogger.check_instance_created(logger):
logger_instance = MMLogger.get_instance(logger)
logger_instance.log(level, msg)
else:
raise ValueError(f'MMLogger: {logger} has not been created!')
else:
raise TypeError(

View File

@ -6,14 +6,14 @@ from typing import Any, Union
import numpy as np
import torch
from mmengine.utils import ManagerMixin
from mmengine.visualization.utils import check_type
from .base_global_accsessible import BaseGlobalAccessible
from .log_buffer import LogBuffer
class MessageHub(BaseGlobalAccessible):
class MessageHub(ManagerMixin):
"""Message hub for component interaction. MessageHub is created and
accessed in the same way as BaseGlobalAccessible.
accessed in the same way as ManagerMixin.
``MessageHub`` will record log information and runtime information. The
log information refers to the learning rate, loss, etc. of the model
@ -52,7 +52,7 @@ class MessageHub(BaseGlobalAccessible):
log_dict (str): Used for batch updating :attr:`_log_buffers`.
Examples:
>>> message_hub = MessageHub.create_instance()
>>> message_hub = MessageHub.get_instance('mmengine')
>>> log_dict = dict(a=1, b=2, c=3)
>>> message_hub.update_log_vars(log_dict)
>>> # The default count of `a`, `b` and `c` is 1.

View File

@ -602,7 +602,7 @@ class Runner:
'logger should be MMLogger object, a dict or None, '
f'but got {logger}')
return MMLogger.create_instance(**logger)
return MMLogger.get_instance(**logger)
def build_message_hub(
self,
@ -632,7 +632,7 @@ class Runner:
'message_hub should be MessageHub object, a dict or None, '
f'but got {message_hub}')
return MessageHub.create_instance(**message_hub)
return MessageHub.get_instance(**message_hub)
def build_writer(
self,
@ -664,7 +664,7 @@ class Runner:
'writer should be ComposedWriter object, a dict or None, '
f'but got {writer}')
return ComposedWriter.create_instance(**writer)
return ComposedWriter.get_instance(**writer)
def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
"""Build model.

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hub import load_url
from .manager import ManagerMeta, ManagerMixin
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
find_latest_checkpoint, has_method,
import_modules_from_strings, is_list_of,
@ -22,5 +23,5 @@ __all__ = [
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_method_overridden', 'has_method', 'mmcv_full_available',
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
'find_latest_checkpoint'
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin'
]

View File

@ -0,0 +1,161 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import threading
from collections import OrderedDict
from typing import Any
_lock = threading.RLock()
def _accquire_lock() -> None:
"""Acquire the module-level lock for serializing access to shared data.
This should be released with _release_lock().
"""
if _lock:
_lock.acquire()
def _release_lock() -> None:
"""Release the module-level lock acquired by calling _accquire_lock()."""
if _lock:
_lock.release()
class ManagerMeta(type):
"""The metaclass for global accessible class.
The subclasses inheriting from ``ManagerMeta`` will manage their
own ``_instance_dict`` and root instances. The constructors of subclasses
must contain the ``name`` argument.
Examples:
>>> class SubClass1(metaclass=ManagerMeta):
>>> def __init__(self, *args, **kwargs):
>>> pass
AssertionError: <class '__main__.SubClass1'>.__init__ must have the
name argument.
>>> class SubClass2(metaclass=ManagerMeta):
>>> def __init__(self, name):
>>> pass
>>> # valid format.
"""
def __init__(cls, *args):
cls._instance_dict = OrderedDict()
params = inspect.getfullargspec(cls)
params_names = params[0] if params[0] else []
assert 'name' in params_names, f'{cls} must have the `name` argument'
super().__init__(*args)
class ManagerMixin(metaclass=ManagerMeta):
"""``ManagerMixin`` is the base class for classes that have global access
requirements.
The subclasses inheriting from ``ManagerMixin`` can get their
global instances.
Examples:
>>> class GlobalAccessible(ManagerMixin):
>>> def __init__(self, name=''):
>>> super().__init__(name)
>>>
>>> GlobalAccessible.get_instance('name')
>>> instance_1 = GlobalAccessible.get_instance('name')
>>> instance_2 = GlobalAccessible.get_instance('name')
>>> assert id(instance_1) == id(instance_2)
Args:
name (str): Name of the instance. Defaults to ''.
"""
def __init__(self, name: str = '', **kwargs):
self._instance_name = name
@classmethod
def get_instance(cls, name: str, **kwargs) -> Any:
"""Get subclass instance by name if the name exists.
If corresponding name instance has not been created, ``get_instance``
will create an instance, otherwise ``get_instance`` will return the
corresponding instance.
Examples
>>> instance1 = GlobalAccessible.get_instance('name1')
>>> # Create name1 instance.
>>> instance.instance_name
name1
>>> instance2 = GlobalAccessible.get_instance('name1')
>>> # Get name1 instance.
>>> assert id(instance1) == id(instance2)
Args:
name (str): Name of instance. Defaults to ''.
Returns:
object: Corresponding name instance, the latest instance, or root
instance.
"""
_accquire_lock()
assert isinstance(name, str), \
f'type of name should be str, but got {type(cls)}'
instance_dict = cls._instance_dict
# Get the instance by name.
if name not in instance_dict:
instance = cls(name=name, **kwargs)
instance_dict[name] = instance
# Get latest instantiated instance or root instance.
_release_lock()
return instance_dict[name]
@classmethod
def get_current_instance(cls):
"""Get latest created instance.
Before calling ``get_current_instance``, The subclass must have called
``get_instance(xxx)`` at least once.
Examples
>>> instance = GlobalAccessible.get_current_instance(current=True)
AssertionError: At least one of name and current needs to be set
>>> instance = GlobalAccessible.get_instance('name1')
>>> instance.instance_name
name1
>>> instance = GlobalAccessible.get_current_instance(current=True)
>>> instance.instance_name
name1
Returns:
object: Latest created instance.
"""
_accquire_lock()
if not cls._instance_dict:
raise RuntimeError(
f'Before calling {cls.__name__}.get_instance('
'current=True), '
'you should call get_instance(name=xxx) at least once.')
name = next(iter(reversed(cls._instance_dict)))
_release_lock()
return cls._instance_dict[name]
@classmethod
def check_instance_created(cls, name: str) -> bool:
"""Check whether the name corresponding instance exists.
Args:
name (str): Name of instance.
Returns:
bool: Whether the name corresponding instance exists.
"""
return name in cls._instance_dict
@property
def instance_name(self) -> str:
"""Get the name of instance.
Returns:
str: Name of instance.
"""
return self._instance_name

View File

@ -11,9 +11,8 @@ import torch
from mmengine.data import BaseDataSample
from mmengine.fileio import dump
from mmengine.logging import BaseGlobalAccessible
from mmengine.registry import VISUALIZERS, WRITERS
from mmengine.utils import TORCH_VERSION
from mmengine.utils import TORCH_VERSION, ManagerMixin
from .visualizer import Visualizer
@ -676,15 +675,15 @@ class TensorboardWriter(BaseWriter):
self._tensorboard.close()
class ComposedWriter(BaseGlobalAccessible):
class ComposedWriter(ManagerMixin):
"""Wrapper class to compose multiple a subclass of :class:`BaseWriter`
instances. By inheriting BaseGlobalAccessible, it can be accessed anywhere
once instantiated.
instances. By inheriting ManagerMixin, it can be accessed anywhere once
instantiated.
Examples:
>>> from mmengine.visualization import ComposedWriter
>>> import numpy as np
>>> composed_writer= ComposedWriter.create_instance( \
>>> composed_writer= ComposedWriter.get_instance( \
'composed_writer', writers=[dict(type='LocalWriter', \
visualizer=dict(type='DetVisualizer'), \
save_dir='temp_dir'), dict(type='WandbWriter')])

View File

@ -1,110 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
from mmengine.logging import BaseGlobalAccessible, MetaGlobalAccessible
class SubClassA(BaseGlobalAccessible):
def __init__(self, name='', *args, **kwargs):
super().__init__(name, *args, **kwargs)
class SubClassB(BaseGlobalAccessible):
def __init__(self, name='', *args, **kwargs):
super().__init__(name, *args, **kwargs)
class TestGlobalMeta:
def test_init(self):
# Subclass's constructor does not contain name arguments will raise an
# error.
with pytest.raises(AssertionError):
class SubClassNoName1(metaclass=MetaGlobalAccessible):
def __init__(self, a, *args, **kwargs):
pass
# The constructor of subclasses must have default values for all
# arguments except name. Since `MetaGlobalAccessible` cannot tell which
# parameter does not have ha default value, we should test invalid
# subclasses separately.
with pytest.raises(AssertionError):
class SubClassNoDefault1(metaclass=MetaGlobalAccessible):
def __init__(self, a, name='', *args, **kwargs):
pass
with pytest.raises(AssertionError):
class SubClassNoDefault2(metaclass=MetaGlobalAccessible):
def __init__(self, a, b, name='', *args, **kwargs):
pass
# Valid subclass.
class GlobalAccessible1(metaclass=MetaGlobalAccessible):
def __init__(self, name):
self.name = name
# Allow name not to be the first arguments.
class GlobalAccessible2(metaclass=MetaGlobalAccessible):
def __init__(self, a=1, name=''):
self.name = name
assert GlobalAccessible1.root.name == 'root'
class TestBaseGlobalAccessible:
def test_init(self):
# test get root instance.
assert BaseGlobalAccessible.root._name == 'root'
# test create instance by name.
base_cls = BaseGlobalAccessible('name')
assert base_cls._name == 'name'
def test_create_instance(self):
# SubClass should manage their own `_instance_dict`.
SubClassA.create_instance('instance_a')
SubClassB.create_instance('instance_b')
assert SubClassB._instance_dict != SubClassA._instance_dict
# test `message_hub` can create by name.
message_hub = SubClassA.create_instance('name1')
assert message_hub.instance_name == 'name1'
# test return root message_hub
message_hub = SubClassA.create_instance()
assert message_hub.instance_name == 'root'
# test default get root `message_hub`.
def test_get_instance(self):
message_hub = SubClassA.get_instance()
assert message_hub.instance_name == 'root'
# test default get latest `message_hub`.
message_hub = SubClassA.create_instance('name2')
message_hub = SubClassA.get_instance(current=True)
assert message_hub.instance_name == 'name2'
message_hub.mark = -1
# test get latest `message_hub` repeatedly.
message_hub = SubClassA.create_instance('name3')
assert message_hub.instance_name == 'name3'
message_hub = SubClassA.get_instance(current=True)
assert message_hub.instance_name == 'name3'
# test get root repeatedly.
message_hub = SubClassA.get_instance()
assert message_hub.instance_name == 'root'
# test get name1 repeatedly
message_hub = SubClassA.get_instance('name2')
assert message_hub.mark == -1
# create_instance will raise error if `name` is not specified and
# given other arguments
with pytest.raises(ValueError):
SubClassA.create_instance(a=1)

View File

@ -17,7 +17,7 @@ class TestLogger:
@patch('torch.distributed.is_initialized', lambda: True)
@patch('torch.distributed.is_available', lambda: True)
def test_init_rank0(self, tmp_path):
logger = MMLogger.create_instance('rank0.pkg1', log_level='INFO')
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
assert logger.name == 'rank0.pkg1'
assert logger.instance_name == 'rank0.pkg1'
# Logger get from `MMLogger.get_instance` does not inherit from
@ -30,7 +30,7 @@ class TestLogger:
# If `rank=0`, the `log_level` of stream_handler and file_handler
# depends on the given arguments.
tmp_file = tmp_path / 'tmp_file.log'
logger = MMLogger.create_instance(
logger = MMLogger.get_instance(
'rank0.pkg2', log_level='INFO', log_file=str(tmp_file))
assert isinstance(logger, logging.Logger)
assert len(logger.handlers) == 2
@ -47,7 +47,7 @@ class TestLogger:
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
tmp_file = tmp_path / 'tmp_file.log'
log_path = tmp_path / 'rank1_tmp_file.log'
logger = MMLogger.create_instance(
logger = MMLogger.get_instance(
'rank1.pkg2', log_level='INFO', log_file=str(tmp_file))
assert len(logger.handlers) == 2
assert logger.handlers[0].level == logging.ERROR
@ -60,7 +60,7 @@ class TestLogger:
def test_handler(self, capsys, tmp_path, log_level):
# test stream handler can output correct format logs
logger_name = f'test_stream_{str(log_level)}'
logger = MMLogger.create_instance(logger_name, log_level=log_level)
logger = MMLogger.get_instance(logger_name, log_level=log_level)
logger.log(level=log_level, msg='welcome')
out, _ = capsys.readouterr()
# Skip match colored INFO
@ -73,7 +73,7 @@ class TestLogger:
# test file_handler output plain text without color.
tmp_file = tmp_path / 'tmp_file.log'
logger_name = f'test_file_{log_level}'
logger = MMLogger.create_instance(
logger = MMLogger.get_instance(
logger_name, log_level=log_level, log_file=tmp_file)
logger.log(level=log_level, msg='welcome')
with open(tmp_file, 'r') as f:
@ -87,7 +87,7 @@ class TestLogger:
def test_erro_format(self, capsys):
# test error level log can output file path, function name and
# line number
logger = MMLogger.create_instance('test_error', log_level='INFO')
logger = MMLogger.get_instance('test_error', log_level='INFO')
logger.error('welcome')
lineno = sys._getframe().f_lineno - 1
file_path = __file__
@ -109,7 +109,7 @@ class TestLogger:
print_log('welcome', logger='silent')
out, _ = capsys.readouterr()
assert out == ''
logger = MMLogger.create_instance('test_print_log')
logger = MMLogger.get_instance('test_print_log')
# Test using specified logger
print_log('welcome', logger=logger)
out, _ = capsys.readouterr()

View File

@ -15,7 +15,7 @@ class TestMessageHub:
assert len(message_hub.log_buffers) == 0
def test_update_log(self):
message_hub = MessageHub.create_instance()
message_hub = MessageHub.get_instance('mmengine')
# test create target `LogBuffer` by name
message_hub.update_log('name', 1)
log_buffer = message_hub.log_buffers['name']
@ -26,7 +26,7 @@ class TestMessageHub:
# unmatched string will raise a key error
def test_update_info(self):
message_hub = MessageHub.create_instance()
message_hub = MessageHub.get_instance('mmengine')
# test runtime value can be overwritten.
message_hub.update_info('key', 2)
assert message_hub.runtime_info['key'] == 2
@ -34,7 +34,7 @@ class TestMessageHub:
assert message_hub.runtime_info['key'] == 1
def test_get_log_buffers(self):
message_hub = MessageHub.create_instance()
message_hub = MessageHub.get_instance('mmengine')
# Get undefined key will raise error
with pytest.raises(KeyError):
message_hub.get_log('unknown')
@ -50,7 +50,7 @@ class TestMessageHub:
assert (recorded_count == count).all()
def test_get_runtime(self):
message_hub = MessageHub.create_instance()
message_hub = MessageHub.get_instance('mmengine')
with pytest.raises(KeyError):
message_hub.get_info('unknown')
recorded_dict = dict(a=1, b=2)
@ -58,7 +58,7 @@ class TestMessageHub:
assert message_hub.get_info('test_value') == recorded_dict
def test_get_log_vars(self):
message_hub = MessageHub.create_instance()
message_hub = MessageHub.get_instance('mmengine')
log_dict = dict(
loss=1,
loss_cls=torch.tensor(2),

View File

@ -0,0 +1,72 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
from mmengine.utils import ManagerMeta, ManagerMixin
class SubClassA(ManagerMixin):
def __init__(self, name='', *args, **kwargs):
super().__init__(name, *args, **kwargs)
class SubClassB(ManagerMixin):
def __init__(self, name='', *args, **kwargs):
super().__init__(name, *args, **kwargs)
class TestGlobalMeta:
def test_init(self):
# Subclass's constructor does not contain name arguments will raise an
# error.
with pytest.raises(AssertionError):
class SubClassNoName1(metaclass=ManagerMeta):
def __init__(self, a, *args, **kwargs):
pass
# Valid subclass.
class GlobalAccessible1(metaclass=ManagerMeta):
def __init__(self, name):
self.name = name
class TestManagerMixin:
def test_init(self):
# test create instance by name.
base_cls = ManagerMixin('name')
assert base_cls.instance_name == 'name'
def test_get_instance(self):
# SubClass should manage their own `_instance_dict`.
with pytest.raises(RuntimeError):
SubClassA.get_current_instance()
SubClassA.get_instance('instance_a')
SubClassB.get_instance('instance_b')
assert SubClassB._instance_dict != SubClassA._instance_dict
# Test `message_hub` can create by name.
message_hub = SubClassA.get_instance('name1')
assert message_hub.instance_name == 'name1'
# No arguments will raise an assertion error.
SubClassA.get_instance('name2')
message_hub = SubClassA.get_current_instance()
message_hub.mark = -1
assert message_hub.instance_name == 'name2'
# Test get latest `message_hub` repeatedly.
message_hub = SubClassA.get_instance('name3')
assert message_hub.instance_name == 'name3'
message_hub = SubClassA.get_current_instance()
assert message_hub.instance_name == 'name3'
# Test get name2 repeatedly.
message_hub = SubClassA.get_instance('name2')
assert message_hub.mark == -1
# Non-string instance name will raise `AssertionError`.
with pytest.raises(AssertionError):
SubClassA.get_instance(name=1)

View File

@ -332,7 +332,7 @@ class TestComposedWriter:
assert len(composed_writer._writers) == 2
# test global
composed_writer = ComposedWriter.create_instance(
composed_writer = ComposedWriter.get_instance(
'composed_writer',
writers=[
WandbWriter(),