mirror of https://github.com/open-mmlab/mmcv.git
78 lines
2.1 KiB
Python
78 lines
2.1 KiB
Python
import functools
|
|
import sys
|
|
import time
|
|
from getpass import getuser
|
|
from socket import gethostname
|
|
|
|
import mmcv
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
|
|
def get_host_info():
    """Describe where the current process is running.

    Returns:
        str: The login name of the current user and the host name of this
        machine, joined as ``user@hostname``.
    """
    user = getuser()
    host = gethostname()
    return '%s@%s' % (user, host)
|
|
|
|
|
|
def get_dist_info():
    """Return the distributed rank and world size of the current process.

    When the default process group is not initialized (or this torch build
    has no distributed support at all) the process is treated as the only
    one, i.e. ``(0, 1)`` is returned.

    Returns:
        tuple[int, int]: ``(rank, world_size)``.
    """
    if torch.__version__ < '1.0':
        # Pre-1.0 releases expose the initialization state as a private flag.
        initialized = dist._initialized
    else:
        # When ``is_available()`` is False the rest of the torch.distributed
        # API (including ``is_initialized``) is not defined, so check it
        # first instead of crashing on builds without distributed support.
        initialized = dist.is_available() and dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
|
|
|
|
|
|
def master_only(func):
    """Decorate ``func`` so that it only executes on the rank-0 process.

    The rank is queried via ``get_dist_info()`` on every call. On any other
    rank the wrapped call is skipped and ``None`` is returned.

    Args:
        func (callable): Function to guard.

    Returns:
        callable: Wrapper with ``func``'s metadata copied by
        ``functools.wraps``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        current_rank = get_dist_info()[0]
        if current_rank != 0:
            # Non-master processes skip the call entirely.
            return None
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def get_time_str():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string.

    Useful as a sortable timestamp fragment, e.g. in file names.
    """
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
|
|
|
|
|
|
def _locate_type_by_name(name):
    """Resolve a string type name when ``obj_from_dict`` has no parent.

    A dotted name such as ``"collections.OrderedDict"`` is imported from its
    module; a bare name such as ``"list"`` is looked up among the builtins.

    Raises:
        KeyError: If a bare name is not a builtin.
        ImportError / AttributeError: If a dotted name cannot be imported.
    """
    import importlib
    try:
        import builtins
    except ImportError:  # Python 2
        import __builtin__ as builtins

    if '.' in name:
        module_name, attr_name = name.rsplit('.', 1)
        return getattr(importlib.import_module(module_name), attr_name)
    if hasattr(builtins, name):
        return getattr(builtins, name)
    raise KeyError('cannot locate type "{}": it is not a builtin; pass a '
                   'dotted path or the `parrent` argument'.format(name))


def obj_from_dict(info, parrent=None, default_args=None):
    """Initialize an object from dict.

    The dict must contain the key "type", which indicates the object type, it
    can be either a string or type, such as "list" or ``list``. Remaining
    fields are treated as the arguments for constructing the object.

    Args:
        info (dict): Object types and arguments.
        parrent (:class:`module`, optional): Module (or any object) on which
            a string "type" is looked up as an attribute. If None, a string
            "type" is resolved as a builtin name (e.g. "list") or as a dotted
            import path (e.g. "collections.OrderedDict").
        default_args (dict, optional): Default arguments for initializing the
            object. Keys already present in ``info`` take precedence.

    Returns:
        any type: Object built from the dict.

    Raises:
        TypeError: If "type" is neither a string nor a type.
    """
    assert isinstance(info, dict) and 'type' in info
    assert isinstance(default_args, dict) or default_args is None
    # Work on a copy so the caller's dict keeps its "type" key.
    args = info.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        if parrent is not None:
            obj_type = getattr(parrent, obj_type)
        else:
            # Previously ``sys.modules[obj_type]``, which could only ever yield
            # a (non-callable) module object, never a class like ``list``.
            obj_type = _locate_type_by_name(obj_type)
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)
|