[Enhancement] Refine GlobalAccessble (#144)
* rename global accessible and intergration get_sintance and create_instance * move ManagerMixin to utils * fix as docstring and seporate get_instance to get_instance and get_current_instance * fix lint * fix docstring, rename and move test_global_meta * fix manager's runtime error description fix manager's runtime error description * Add comments * Add commentspull/156/head
parent
2bf099d33c
commit
1048584147
|
@ -36,3 +36,4 @@ Distributed
|
|||
Logging
|
||||
--------
|
||||
.. automodule:: mmengine.logging
|
||||
:members:
|
||||
|
|
|
@ -1,10 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_global_accsessible import BaseGlobalAccessible, MetaGlobalAccessible
|
||||
from .log_buffer import LogBuffer
|
||||
from .logger import MMLogger, print_log
|
||||
from .message_hub import MessageHub
|
||||
|
||||
__all__ = [
|
||||
'LogBuffer', 'MessageHub', 'MetaGlobalAccessible', 'BaseGlobalAccessible',
|
||||
'MMLogger', 'print_log'
|
||||
]
|
||||
__all__ = ['LogBuffer', 'MessageHub', 'MMLogger', 'print_log']
|
||||
|
|
|
@ -1,173 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class MetaGlobalAccessible(type):
|
||||
"""The metaclass for global accessible class.
|
||||
|
||||
The subclasses inheriting from ``MetaGlobalAccessible`` will manage their
|
||||
own ``_instance_dict`` and root instances. The constructors of subclasses
|
||||
must contain an optional ``name`` argument and all other arguments must
|
||||
have default values.
|
||||
|
||||
Examples:
|
||||
>>> class SubClass1(metaclass=MetaGlobalAccessible):
|
||||
>>> def __init__(self, *args, **kwargs):
|
||||
>>> pass
|
||||
AssertionError: <class '__main__.SubClass1'>.__init__ must have the
|
||||
name argument.
|
||||
>>> class SubClass2(metaclass=MetaGlobalAccessible):
|
||||
>>> def __init__(self, a, name=None, **kwargs):
|
||||
>>> pass
|
||||
AssertionError:
|
||||
In <class '__main__.SubClass2'>.__init__, Only the name argument is
|
||||
allowed to have no default values.
|
||||
>>> class SubClass3(metaclass=MetaGlobalAccessible):
|
||||
>>> def __init__(self, name, **kwargs):
|
||||
>>> pass # Right format
|
||||
>>> class SubClass4(metaclass=MetaGlobalAccessible):
|
||||
>>> def __init__(self, a=1, name='', **kwargs):
|
||||
>>> pass # Right format
|
||||
"""
|
||||
|
||||
def __init__(cls, *args):
|
||||
cls._instance_dict = OrderedDict()
|
||||
params = inspect.getfullargspec(cls)
|
||||
# `inspect.getfullargspec` returns a tuple includes `(args, varargs,
|
||||
# varkw, defaults, kwonlyargs, kwonlydefaults, annotations)`.
|
||||
# To make sure `cls(name='root')` can be implemented, the
|
||||
# `args` and `defaults` should be checked.
|
||||
params_names = params[0] if params[0] else []
|
||||
default_params = params[3] if params[3] else []
|
||||
assert 'name' in params_names, f'{cls}.__init__ must have the name ' \
|
||||
'argument'
|
||||
if len(default_params) == len(params_names) - 2 and 'name' != \
|
||||
params[0][1]:
|
||||
raise AssertionError(f'In {cls}.__init__, Only the name argument '
|
||||
'is allowed to have no default values.')
|
||||
if len(default_params) < len(params_names) - 2:
|
||||
raise AssertionError('Besides name, the arguments of the '
|
||||
f'{cls}.__init__ must have default values')
|
||||
cls.root = cls(name='root')
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
class BaseGlobalAccessible(metaclass=MetaGlobalAccessible):
|
||||
"""``BaseGlobalAccessible`` is the base class for classes that have global
|
||||
access requirements.
|
||||
|
||||
The subclasses inheriting from ``BaseGlobalAccessible`` can get their
|
||||
global instancees.
|
||||
|
||||
Examples:
|
||||
>>> class GlobalAccessible(BaseGlobalAccessible):
|
||||
>>> def __init__(self, name=''):
|
||||
>>> super().__init__(name)
|
||||
>>>
|
||||
>>> GlobalAccessible.create_instance('name')
|
||||
>>> instance_1 = GlobalAccessible.get_instance('name')
|
||||
>>> instance_2 = GlobalAccessible.get_instance('name')
|
||||
>>> assert id(instance_1) == id(instance_2)
|
||||
|
||||
Args:
|
||||
name (str): Name of the instance. Defaults to ''.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = '', **kwargs):
|
||||
self._name = name
|
||||
|
||||
@classmethod
|
||||
def create_instance(cls, name: str = '', **kwargs) -> Any:
|
||||
"""Create subclass instance by name, and subclass cannot create
|
||||
instances with duplicated names. The created instance will be stored in
|
||||
``cls._instance_dict``, and can be accessed by ``get_instance``.
|
||||
|
||||
Examples:
|
||||
>>> instance_1 = GlobalAccessible.create_instance('name')
|
||||
>>> instance_2 = GlobalAccessible.create_instance('name')
|
||||
AssertionError: <class '__main__.GlobalAccessible'> cannot be
|
||||
created by name twice.
|
||||
>>> root_instance = GlobalAccessible.create_instance()
|
||||
>>> root_instance.instance_name # get default root instance
|
||||
root
|
||||
|
||||
Args:
|
||||
name (str): Name of instance. Defaults to ''.
|
||||
|
||||
Returns:
|
||||
object: Subclass instance.
|
||||
"""
|
||||
instance_dict = cls._instance_dict
|
||||
# Create instance and fill the instance in the `instance_dict`.
|
||||
if name:
|
||||
assert name not in instance_dict, f'{cls} cannot be created by ' \
|
||||
f'{name} twice.'
|
||||
instance = cls(name=name, **kwargs)
|
||||
instance_dict[name] = instance
|
||||
return instance
|
||||
# Get default root instance.
|
||||
else:
|
||||
if kwargs:
|
||||
raise ValueError('If name is not specified, create_instance '
|
||||
f'will return root {cls} and cannot accept '
|
||||
f'any arguments, but got kwargs: {kwargs}')
|
||||
return cls.root
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, name: str = '', current: bool = False) -> Any:
|
||||
"""Get subclass instance by name if the name exists. if name is not
|
||||
specified, this method will return latest created instance of root
|
||||
instance.
|
||||
|
||||
Examples
|
||||
>>> instance = GlobalAccessible.create_instance('name1')
|
||||
>>> instance = GlobalAccessible.get_instance('name1')
|
||||
>>> instance.instance_name
|
||||
name1
|
||||
>>> instance = GlobalAccessible.create_instance('name2')
|
||||
>>> instance = GlobalAccessible.get_instance(current=True)
|
||||
>>> instance.instance_name
|
||||
name2
|
||||
>>> instance = GlobalAccessible.get_instance()
|
||||
>>> instance.instance_name # get root instance
|
||||
root
|
||||
>>> instance = GlobalAccessible.get_instance('name3') # error
|
||||
AssertionError: Cannot get <class '__main__.GlobalAccessible'> by
|
||||
name: name3, please make sure you have created it
|
||||
|
||||
Args:
|
||||
name (str): Name of instance. Defaults to ''.
|
||||
current(bool): Whether to return the latest created instance or
|
||||
the root instance, if name is not spicified. Defaults to False.
|
||||
|
||||
Returns:
|
||||
object: Corresponding name instance, the latest instance, or root
|
||||
instance.
|
||||
"""
|
||||
instance_dict = cls._instance_dict
|
||||
# Get the instance by name.
|
||||
if name:
|
||||
assert name in instance_dict, \
|
||||
f'Cannot get {cls} by name: {name}, please make sure you ' \
|
||||
'have created it'
|
||||
return instance_dict[name]
|
||||
# Get latest instantiated instance or root instance.
|
||||
else:
|
||||
if current:
|
||||
current_name = next(iter(reversed(cls._instance_dict)))
|
||||
assert current_name, f'Before calling {cls}.get_instance, ' \
|
||||
'you should call create_instance.'
|
||||
return cls._instance_dict[current_name]
|
||||
else:
|
||||
return cls.root
|
||||
|
||||
@property
|
||||
def instance_name(self) -> Optional[str]:
|
||||
"""Get the name of instance.
|
||||
|
||||
Returns:
|
||||
str: Name of instance.
|
||||
"""
|
||||
return self._name
|
|
@ -8,7 +8,7 @@ from typing import Optional, Union
|
|||
import torch.distributed as dist
|
||||
from termcolor import colored
|
||||
|
||||
from .base_global_accsessible import BaseGlobalAccessible
|
||||
from mmengine.utils import ManagerMixin
|
||||
|
||||
|
||||
class MMFormatter(logging.Formatter):
|
||||
|
@ -84,10 +84,10 @@ class MMFormatter(logging.Formatter):
|
|||
return result
|
||||
|
||||
|
||||
class MMLogger(Logger, BaseGlobalAccessible):
|
||||
class MMLogger(Logger, ManagerMixin):
|
||||
"""The Logger manager which can create formatted logger and get specified
|
||||
logger globally. MMLogger is created and accessed in the same way as
|
||||
BaseGlobalAccessible.
|
||||
ManagerMixin.
|
||||
|
||||
Args:
|
||||
name (str): Logger name. Defaults to ''.
|
||||
|
@ -104,7 +104,7 @@ class MMLogger(Logger, BaseGlobalAccessible):
|
|||
log_level: str = 'NOTSET',
|
||||
file_mode: str = 'w'):
|
||||
Logger.__init__(self, name)
|
||||
BaseGlobalAccessible.__init__(self, name)
|
||||
ManagerMixin.__init__(self, name)
|
||||
# Get rank in DDP mode.
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
|
@ -137,19 +137,22 @@ class MMLogger(Logger, BaseGlobalAccessible):
|
|||
|
||||
def print_log(msg,
|
||||
logger: Optional[Union[Logger, str]] = None,
|
||||
level=logging.INFO):
|
||||
level=logging.INFO) -> None:
|
||||
"""Print a log message.
|
||||
|
||||
Args:
|
||||
msg (str): The message to be logged.
|
||||
logger (Logger or str, optional): The logger to be used.
|
||||
logger (Logger or str, optional): If the type of logger is
|
||||
``logging.Logger``, we directly use logger to log messages.
|
||||
Some special loggers are:
|
||||
- "silent": no message will be printed.
|
||||
- "current": Log message via the latest created logger.
|
||||
- other str: the logger obtained with `MMLogger.get_instance`.
|
||||
- "silent": No message will be printed.
|
||||
- "current": Use latest created logger to log message.
|
||||
- other str: Instance name of logger. The corresponding logger
|
||||
will log message if it has been created, otherwise ``print_log``
|
||||
will raise a `ValueError`.
|
||||
- None: The `print()` method will be used to print log messages.
|
||||
level (int): Logging level. Only available when `logger` is a Logger
|
||||
object or "root".
|
||||
object, "current", or a created logger instance name.
|
||||
"""
|
||||
if logger is None:
|
||||
print(msg)
|
||||
|
@ -158,13 +161,17 @@ def print_log(msg,
|
|||
elif logger == 'silent':
|
||||
pass
|
||||
elif logger == 'current':
|
||||
logger_instance = MMLogger.get_instance(current=True)
|
||||
logger_instance = MMLogger.get_current_instance()
|
||||
logger_instance.log(level, msg)
|
||||
elif isinstance(logger, str):
|
||||
try:
|
||||
_logger = MMLogger.get_instance(logger)
|
||||
_logger.log(level, msg)
|
||||
except AssertionError:
|
||||
# If the type of `logger` is `str`, but not with value of `current` or
|
||||
# `silent`, we assume it indicates the name of the logger. If the
|
||||
# corresponding logger has not been created, `print_log` will raise
|
||||
# a `ValueError`.
|
||||
if MMLogger.check_instance_created(logger):
|
||||
logger_instance = MMLogger.get_instance(logger)
|
||||
logger_instance.log(level, msg)
|
||||
else:
|
||||
raise ValueError(f'MMLogger: {logger} has not been created!')
|
||||
else:
|
||||
raise TypeError(
|
||||
|
|
|
@ -6,14 +6,14 @@ from typing import Any, Union
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmengine.utils import ManagerMixin
|
||||
from mmengine.visualization.utils import check_type
|
||||
from .base_global_accsessible import BaseGlobalAccessible
|
||||
from .log_buffer import LogBuffer
|
||||
|
||||
|
||||
class MessageHub(BaseGlobalAccessible):
|
||||
class MessageHub(ManagerMixin):
|
||||
"""Message hub for component interaction. MessageHub is created and
|
||||
accessed in the same way as BaseGlobalAccessible.
|
||||
accessed in the same way as ManagerMixin.
|
||||
|
||||
``MessageHub`` will record log information and runtime information. The
|
||||
log information refers to the learning rate, loss, etc. of the model
|
||||
|
@ -52,7 +52,7 @@ class MessageHub(BaseGlobalAccessible):
|
|||
log_dict (str): Used for batch updating :attr:`_log_buffers`.
|
||||
|
||||
Examples:
|
||||
>>> message_hub = MessageHub.create_instance()
|
||||
>>> message_hub = MessageHub.get_instance('mmengine')
|
||||
>>> log_dict = dict(a=1, b=2, c=3)
|
||||
>>> message_hub.update_log_vars(log_dict)
|
||||
>>> # The default count of `a`, `b` and `c` is 1.
|
||||
|
|
|
@ -602,7 +602,7 @@ class Runner:
|
|||
'logger should be MMLogger object, a dict or None, '
|
||||
f'but got {logger}')
|
||||
|
||||
return MMLogger.create_instance(**logger)
|
||||
return MMLogger.get_instance(**logger)
|
||||
|
||||
def build_message_hub(
|
||||
self,
|
||||
|
@ -632,7 +632,7 @@ class Runner:
|
|||
'message_hub should be MessageHub object, a dict or None, '
|
||||
f'but got {message_hub}')
|
||||
|
||||
return MessageHub.create_instance(**message_hub)
|
||||
return MessageHub.get_instance(**message_hub)
|
||||
|
||||
def build_writer(
|
||||
self,
|
||||
|
@ -664,7 +664,7 @@ class Runner:
|
|||
'writer should be ComposedWriter object, a dict or None, '
|
||||
f'but got {writer}')
|
||||
|
||||
return ComposedWriter.create_instance(**writer)
|
||||
return ComposedWriter.get_instance(**writer)
|
||||
|
||||
def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
|
||||
"""Build model.
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hub import load_url
|
||||
from .manager import ManagerMeta, ManagerMixin
|
||||
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
||||
find_latest_checkpoint, has_method,
|
||||
import_modules_from_strings, is_list_of,
|
||||
|
@ -22,5 +23,5 @@ __all__ = [
|
|||
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
|
||||
'is_method_overridden', 'has_method', 'mmcv_full_available',
|
||||
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
||||
'find_latest_checkpoint'
|
||||
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import inspect
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
_lock = threading.RLock()
|
||||
|
||||
|
||||
def _accquire_lock() -> None:
|
||||
"""Acquire the module-level lock for serializing access to shared data.
|
||||
|
||||
This should be released with _release_lock().
|
||||
"""
|
||||
if _lock:
|
||||
_lock.acquire()
|
||||
|
||||
|
||||
def _release_lock() -> None:
|
||||
"""Release the module-level lock acquired by calling _accquire_lock()."""
|
||||
if _lock:
|
||||
_lock.release()
|
||||
|
||||
|
||||
class ManagerMeta(type):
|
||||
"""The metaclass for global accessible class.
|
||||
|
||||
The subclasses inheriting from ``ManagerMeta`` will manage their
|
||||
own ``_instance_dict`` and root instances. The constructors of subclasses
|
||||
must contain the ``name`` argument.
|
||||
|
||||
Examples:
|
||||
>>> class SubClass1(metaclass=ManagerMeta):
|
||||
>>> def __init__(self, *args, **kwargs):
|
||||
>>> pass
|
||||
AssertionError: <class '__main__.SubClass1'>.__init__ must have the
|
||||
name argument.
|
||||
>>> class SubClass2(metaclass=ManagerMeta):
|
||||
>>> def __init__(self, name):
|
||||
>>> pass
|
||||
>>> # valid format.
|
||||
"""
|
||||
|
||||
def __init__(cls, *args):
|
||||
cls._instance_dict = OrderedDict()
|
||||
params = inspect.getfullargspec(cls)
|
||||
params_names = params[0] if params[0] else []
|
||||
assert 'name' in params_names, f'{cls} must have the `name` argument'
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
class ManagerMixin(metaclass=ManagerMeta):
|
||||
"""``ManagerMixin`` is the base class for classes that have global access
|
||||
requirements.
|
||||
|
||||
The subclasses inheriting from ``ManagerMixin`` can get their
|
||||
global instances.
|
||||
|
||||
Examples:
|
||||
>>> class GlobalAccessible(ManagerMixin):
|
||||
>>> def __init__(self, name=''):
|
||||
>>> super().__init__(name)
|
||||
>>>
|
||||
>>> GlobalAccessible.get_instance('name')
|
||||
>>> instance_1 = GlobalAccessible.get_instance('name')
|
||||
>>> instance_2 = GlobalAccessible.get_instance('name')
|
||||
>>> assert id(instance_1) == id(instance_2)
|
||||
|
||||
Args:
|
||||
name (str): Name of the instance. Defaults to ''.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = '', **kwargs):
|
||||
self._instance_name = name
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, name: str, **kwargs) -> Any:
|
||||
"""Get subclass instance by name if the name exists.
|
||||
|
||||
If corresponding name instance has not been created, ``get_instance``
|
||||
will create an instance, otherwise ``get_instance`` will return the
|
||||
corresponding instance.
|
||||
|
||||
Examples
|
||||
>>> instance1 = GlobalAccessible.get_instance('name1')
|
||||
>>> # Create name1 instance.
|
||||
>>> instance.instance_name
|
||||
name1
|
||||
>>> instance2 = GlobalAccessible.get_instance('name1')
|
||||
>>> # Get name1 instance.
|
||||
>>> assert id(instance1) == id(instance2)
|
||||
|
||||
Args:
|
||||
name (str): Name of instance. Defaults to ''.
|
||||
|
||||
Returns:
|
||||
object: Corresponding name instance, the latest instance, or root
|
||||
instance.
|
||||
"""
|
||||
_accquire_lock()
|
||||
assert isinstance(name, str), \
|
||||
f'type of name should be str, but got {type(cls)}'
|
||||
instance_dict = cls._instance_dict
|
||||
# Get the instance by name.
|
||||
if name not in instance_dict:
|
||||
instance = cls(name=name, **kwargs)
|
||||
instance_dict[name] = instance
|
||||
# Get latest instantiated instance or root instance.
|
||||
_release_lock()
|
||||
return instance_dict[name]
|
||||
|
||||
@classmethod
|
||||
def get_current_instance(cls):
|
||||
"""Get latest created instance.
|
||||
|
||||
Before calling ``get_current_instance``, The subclass must have called
|
||||
``get_instance(xxx)`` at least once.
|
||||
|
||||
Examples
|
||||
>>> instance = GlobalAccessible.get_current_instance(current=True)
|
||||
AssertionError: At least one of name and current needs to be set
|
||||
>>> instance = GlobalAccessible.get_instance('name1')
|
||||
>>> instance.instance_name
|
||||
name1
|
||||
>>> instance = GlobalAccessible.get_current_instance(current=True)
|
||||
>>> instance.instance_name
|
||||
name1
|
||||
|
||||
Returns:
|
||||
object: Latest created instance.
|
||||
"""
|
||||
_accquire_lock()
|
||||
if not cls._instance_dict:
|
||||
raise RuntimeError(
|
||||
f'Before calling {cls.__name__}.get_instance('
|
||||
'current=True), '
|
||||
'you should call get_instance(name=xxx) at least once.')
|
||||
name = next(iter(reversed(cls._instance_dict)))
|
||||
_release_lock()
|
||||
return cls._instance_dict[name]
|
||||
|
||||
@classmethod
|
||||
def check_instance_created(cls, name: str) -> bool:
|
||||
"""Check whether the name corresponding instance exists.
|
||||
|
||||
Args:
|
||||
name (str): Name of instance.
|
||||
|
||||
Returns:
|
||||
bool: Whether the name corresponding instance exists.
|
||||
"""
|
||||
return name in cls._instance_dict
|
||||
|
||||
@property
|
||||
def instance_name(self) -> str:
|
||||
"""Get the name of instance.
|
||||
|
||||
Returns:
|
||||
str: Name of instance.
|
||||
"""
|
||||
return self._instance_name
|
|
@ -11,9 +11,8 @@ import torch
|
|||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.fileio import dump
|
||||
from mmengine.logging import BaseGlobalAccessible
|
||||
from mmengine.registry import VISUALIZERS, WRITERS
|
||||
from mmengine.utils import TORCH_VERSION
|
||||
from mmengine.utils import TORCH_VERSION, ManagerMixin
|
||||
from .visualizer import Visualizer
|
||||
|
||||
|
||||
|
@ -676,15 +675,15 @@ class TensorboardWriter(BaseWriter):
|
|||
self._tensorboard.close()
|
||||
|
||||
|
||||
class ComposedWriter(BaseGlobalAccessible):
|
||||
class ComposedWriter(ManagerMixin):
|
||||
"""Wrapper class to compose multiple a subclass of :class:`BaseWriter`
|
||||
instances. By inheriting BaseGlobalAccessible, it can be accessed anywhere
|
||||
once instantiated.
|
||||
instances. By inheriting ManagerMixin, it can be accessed anywhere once
|
||||
instantiated.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine.visualization import ComposedWriter
|
||||
>>> import numpy as np
|
||||
>>> composed_writer= ComposedWriter.create_instance( \
|
||||
>>> composed_writer= ComposedWriter.get_instance( \
|
||||
'composed_writer', writers=[dict(type='LocalWriter', \
|
||||
visualizer=dict(type='DetVisualizer'), \
|
||||
save_dir='temp_dir'), dict(type='WandbWriter')])
|
||||
|
|
|
@ -1,110 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
|
||||
from mmengine.logging import BaseGlobalAccessible, MetaGlobalAccessible
|
||||
|
||||
|
||||
class SubClassA(BaseGlobalAccessible):
|
||||
|
||||
def __init__(self, name='', *args, **kwargs):
|
||||
super().__init__(name, *args, **kwargs)
|
||||
|
||||
|
||||
class SubClassB(BaseGlobalAccessible):
|
||||
|
||||
def __init__(self, name='', *args, **kwargs):
|
||||
super().__init__(name, *args, **kwargs)
|
||||
|
||||
|
||||
class TestGlobalMeta:
|
||||
|
||||
def test_init(self):
|
||||
# Subclass's constructor does not contain name arguments will raise an
|
||||
# error.
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
class SubClassNoName1(metaclass=MetaGlobalAccessible):
|
||||
|
||||
def __init__(self, a, *args, **kwargs):
|
||||
pass
|
||||
|
||||
# The constructor of subclasses must have default values for all
|
||||
# arguments except name. Since `MetaGlobalAccessible` cannot tell which
|
||||
# parameter does not have ha default value, we should test invalid
|
||||
# subclasses separately.
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
class SubClassNoDefault1(metaclass=MetaGlobalAccessible):
|
||||
|
||||
def __init__(self, a, name='', *args, **kwargs):
|
||||
pass
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
class SubClassNoDefault2(metaclass=MetaGlobalAccessible):
|
||||
|
||||
def __init__(self, a, b, name='', *args, **kwargs):
|
||||
pass
|
||||
|
||||
# Valid subclass.
|
||||
class GlobalAccessible1(metaclass=MetaGlobalAccessible):
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
# Allow name not to be the first arguments.
|
||||
|
||||
class GlobalAccessible2(metaclass=MetaGlobalAccessible):
|
||||
|
||||
def __init__(self, a=1, name=''):
|
||||
self.name = name
|
||||
|
||||
assert GlobalAccessible1.root.name == 'root'
|
||||
|
||||
|
||||
class TestBaseGlobalAccessible:
|
||||
|
||||
def test_init(self):
|
||||
# test get root instance.
|
||||
assert BaseGlobalAccessible.root._name == 'root'
|
||||
# test create instance by name.
|
||||
base_cls = BaseGlobalAccessible('name')
|
||||
assert base_cls._name == 'name'
|
||||
|
||||
def test_create_instance(self):
|
||||
# SubClass should manage their own `_instance_dict`.
|
||||
SubClassA.create_instance('instance_a')
|
||||
SubClassB.create_instance('instance_b')
|
||||
assert SubClassB._instance_dict != SubClassA._instance_dict
|
||||
|
||||
# test `message_hub` can create by name.
|
||||
message_hub = SubClassA.create_instance('name1')
|
||||
assert message_hub.instance_name == 'name1'
|
||||
# test return root message_hub
|
||||
message_hub = SubClassA.create_instance()
|
||||
assert message_hub.instance_name == 'root'
|
||||
# test default get root `message_hub`.
|
||||
|
||||
def test_get_instance(self):
|
||||
message_hub = SubClassA.get_instance()
|
||||
assert message_hub.instance_name == 'root'
|
||||
# test default get latest `message_hub`.
|
||||
message_hub = SubClassA.create_instance('name2')
|
||||
message_hub = SubClassA.get_instance(current=True)
|
||||
assert message_hub.instance_name == 'name2'
|
||||
message_hub.mark = -1
|
||||
# test get latest `message_hub` repeatedly.
|
||||
message_hub = SubClassA.create_instance('name3')
|
||||
assert message_hub.instance_name == 'name3'
|
||||
message_hub = SubClassA.get_instance(current=True)
|
||||
assert message_hub.instance_name == 'name3'
|
||||
# test get root repeatedly.
|
||||
message_hub = SubClassA.get_instance()
|
||||
assert message_hub.instance_name == 'root'
|
||||
# test get name1 repeatedly
|
||||
message_hub = SubClassA.get_instance('name2')
|
||||
assert message_hub.mark == -1
|
||||
# create_instance will raise error if `name` is not specified and
|
||||
# given other arguments
|
||||
with pytest.raises(ValueError):
|
||||
SubClassA.create_instance(a=1)
|
|
@ -17,7 +17,7 @@ class TestLogger:
|
|||
@patch('torch.distributed.is_initialized', lambda: True)
|
||||
@patch('torch.distributed.is_available', lambda: True)
|
||||
def test_init_rank0(self, tmp_path):
|
||||
logger = MMLogger.create_instance('rank0.pkg1', log_level='INFO')
|
||||
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
|
||||
assert logger.name == 'rank0.pkg1'
|
||||
assert logger.instance_name == 'rank0.pkg1'
|
||||
# Logger get from `MMLogger.get_instance` does not inherit from
|
||||
|
@ -30,7 +30,7 @@ class TestLogger:
|
|||
# If `rank=0`, the `log_level` of stream_handler and file_handler
|
||||
# depends on the given arguments.
|
||||
tmp_file = tmp_path / 'tmp_file.log'
|
||||
logger = MMLogger.create_instance(
|
||||
logger = MMLogger.get_instance(
|
||||
'rank0.pkg2', log_level='INFO', log_file=str(tmp_file))
|
||||
assert isinstance(logger, logging.Logger)
|
||||
assert len(logger.handlers) == 2
|
||||
|
@ -47,7 +47,7 @@ class TestLogger:
|
|||
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
|
||||
tmp_file = tmp_path / 'tmp_file.log'
|
||||
log_path = tmp_path / 'rank1_tmp_file.log'
|
||||
logger = MMLogger.create_instance(
|
||||
logger = MMLogger.get_instance(
|
||||
'rank1.pkg2', log_level='INFO', log_file=str(tmp_file))
|
||||
assert len(logger.handlers) == 2
|
||||
assert logger.handlers[0].level == logging.ERROR
|
||||
|
@ -60,7 +60,7 @@ class TestLogger:
|
|||
def test_handler(self, capsys, tmp_path, log_level):
|
||||
# test stream handler can output correct format logs
|
||||
logger_name = f'test_stream_{str(log_level)}'
|
||||
logger = MMLogger.create_instance(logger_name, log_level=log_level)
|
||||
logger = MMLogger.get_instance(logger_name, log_level=log_level)
|
||||
logger.log(level=log_level, msg='welcome')
|
||||
out, _ = capsys.readouterr()
|
||||
# Skip match colored INFO
|
||||
|
@ -73,7 +73,7 @@ class TestLogger:
|
|||
# test file_handler output plain text without color.
|
||||
tmp_file = tmp_path / 'tmp_file.log'
|
||||
logger_name = f'test_file_{log_level}'
|
||||
logger = MMLogger.create_instance(
|
||||
logger = MMLogger.get_instance(
|
||||
logger_name, log_level=log_level, log_file=tmp_file)
|
||||
logger.log(level=log_level, msg='welcome')
|
||||
with open(tmp_file, 'r') as f:
|
||||
|
@ -87,7 +87,7 @@ class TestLogger:
|
|||
def test_erro_format(self, capsys):
|
||||
# test error level log can output file path, function name and
|
||||
# line number
|
||||
logger = MMLogger.create_instance('test_error', log_level='INFO')
|
||||
logger = MMLogger.get_instance('test_error', log_level='INFO')
|
||||
logger.error('welcome')
|
||||
lineno = sys._getframe().f_lineno - 1
|
||||
file_path = __file__
|
||||
|
@ -109,7 +109,7 @@ class TestLogger:
|
|||
print_log('welcome', logger='silent')
|
||||
out, _ = capsys.readouterr()
|
||||
assert out == ''
|
||||
logger = MMLogger.create_instance('test_print_log')
|
||||
logger = MMLogger.get_instance('test_print_log')
|
||||
# Test using specified logger
|
||||
print_log('welcome', logger=logger)
|
||||
out, _ = capsys.readouterr()
|
||||
|
|
|
@ -15,7 +15,7 @@ class TestMessageHub:
|
|||
assert len(message_hub.log_buffers) == 0
|
||||
|
||||
def test_update_log(self):
|
||||
message_hub = MessageHub.create_instance()
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
# test create target `LogBuffer` by name
|
||||
message_hub.update_log('name', 1)
|
||||
log_buffer = message_hub.log_buffers['name']
|
||||
|
@ -26,7 +26,7 @@ class TestMessageHub:
|
|||
# unmatched string will raise a key error
|
||||
|
||||
def test_update_info(self):
|
||||
message_hub = MessageHub.create_instance()
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
# test runtime value can be overwritten.
|
||||
message_hub.update_info('key', 2)
|
||||
assert message_hub.runtime_info['key'] == 2
|
||||
|
@ -34,7 +34,7 @@ class TestMessageHub:
|
|||
assert message_hub.runtime_info['key'] == 1
|
||||
|
||||
def test_get_log_buffers(self):
|
||||
message_hub = MessageHub.create_instance()
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
# Get undefined key will raise error
|
||||
with pytest.raises(KeyError):
|
||||
message_hub.get_log('unknown')
|
||||
|
@ -50,7 +50,7 @@ class TestMessageHub:
|
|||
assert (recorded_count == count).all()
|
||||
|
||||
def test_get_runtime(self):
|
||||
message_hub = MessageHub.create_instance()
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
with pytest.raises(KeyError):
|
||||
message_hub.get_info('unknown')
|
||||
recorded_dict = dict(a=1, b=2)
|
||||
|
@ -58,7 +58,7 @@ class TestMessageHub:
|
|||
assert message_hub.get_info('test_value') == recorded_dict
|
||||
|
||||
def test_get_log_vars(self):
|
||||
message_hub = MessageHub.create_instance()
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
log_dict = dict(
|
||||
loss=1,
|
||||
loss_cls=torch.tensor(2),
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
|
||||
from mmengine.utils import ManagerMeta, ManagerMixin
|
||||
|
||||
|
||||
class SubClassA(ManagerMixin):
|
||||
|
||||
def __init__(self, name='', *args, **kwargs):
|
||||
super().__init__(name, *args, **kwargs)
|
||||
|
||||
|
||||
class SubClassB(ManagerMixin):
|
||||
|
||||
def __init__(self, name='', *args, **kwargs):
|
||||
super().__init__(name, *args, **kwargs)
|
||||
|
||||
|
||||
class TestGlobalMeta:
|
||||
|
||||
def test_init(self):
|
||||
# Subclass's constructor does not contain name arguments will raise an
|
||||
# error.
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
class SubClassNoName1(metaclass=ManagerMeta):
|
||||
|
||||
def __init__(self, a, *args, **kwargs):
|
||||
pass
|
||||
|
||||
# Valid subclass.
|
||||
class GlobalAccessible1(metaclass=ManagerMeta):
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class TestManagerMixin:
|
||||
|
||||
def test_init(self):
|
||||
# test create instance by name.
|
||||
base_cls = ManagerMixin('name')
|
||||
assert base_cls.instance_name == 'name'
|
||||
|
||||
def test_get_instance(self):
|
||||
# SubClass should manage their own `_instance_dict`.
|
||||
with pytest.raises(RuntimeError):
|
||||
SubClassA.get_current_instance()
|
||||
SubClassA.get_instance('instance_a')
|
||||
SubClassB.get_instance('instance_b')
|
||||
assert SubClassB._instance_dict != SubClassA._instance_dict
|
||||
|
||||
# Test `message_hub` can create by name.
|
||||
message_hub = SubClassA.get_instance('name1')
|
||||
assert message_hub.instance_name == 'name1'
|
||||
# No arguments will raise an assertion error.
|
||||
|
||||
SubClassA.get_instance('name2')
|
||||
message_hub = SubClassA.get_current_instance()
|
||||
message_hub.mark = -1
|
||||
assert message_hub.instance_name == 'name2'
|
||||
# Test get latest `message_hub` repeatedly.
|
||||
message_hub = SubClassA.get_instance('name3')
|
||||
assert message_hub.instance_name == 'name3'
|
||||
message_hub = SubClassA.get_current_instance()
|
||||
assert message_hub.instance_name == 'name3'
|
||||
# Test get name2 repeatedly.
|
||||
message_hub = SubClassA.get_instance('name2')
|
||||
assert message_hub.mark == -1
|
||||
# Non-string instance name will raise `AssertionError`.
|
||||
with pytest.raises(AssertionError):
|
||||
SubClassA.get_instance(name=1)
|
|
@ -332,7 +332,7 @@ class TestComposedWriter:
|
|||
assert len(composed_writer._writers) == 2
|
||||
|
||||
# test global
|
||||
composed_writer = ComposedWriter.create_instance(
|
||||
composed_writer = ComposedWriter.get_instance(
|
||||
'composed_writer',
|
||||
writers=[
|
||||
WandbWriter(),
|
||||
|
|
Loading…
Reference in New Issue