mmcv/mmcv/runner/utils.py

78 lines
2.1 KiB
Python

import functools
import sys
import time
from getpass import getuser
from socket import gethostname
import mmcv
import torch
import torch.distributed as dist
def get_host_info():
    """Return the current user and host name formatted as ``user@host``."""
    user_name = getuser()
    host_name = gethostname()
    return '{}@{}'.format(user_name, host_name)
def get_dist_info():
    """Return the rank of the current process and the world size.

    When no process group has been initialized, or when the installed
    PyTorch build reports that distributed support is unavailable, the
    single-process defaults ``(0, 1)`` are returned.

    Returns:
        tuple[int, int]: ``(rank, world_size)``.
    """
    if torch.__version__ < '1.0':
        # Releases before 1.0 expose the state as a private module attribute.
        initialized = dist._initialized
    elif dist.is_available():
        initialized = dist.is_initialized()
    else:
        # Without distributed support ``dist.is_initialized`` may not be
        # defined at all, so do not call it.
        initialized = False

    if not initialized:
        return 0, 1
    return dist.get_rank(), dist.get_world_size()
def master_only(func):
    """Decorate ``func`` so that it only runs when the rank is 0.

    The rank is read from ``get_dist_info()`` on every call. When it is
    not 0 the wrapped function is skipped and ``None`` is returned.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        current_rank, _world_size = get_dist_info()
        if current_rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
def get_time_str():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
def obj_from_dict(info, parrent=None, default_args=None):
    """Initialize an object from dict.

    The dict must contain the key "type", which indicates the object type, it
    can be either a string or type, such as "list" or ``list``. Remaining
    fields are treated as the arguments for constructing the object.

    Args:
        info (dict): Object types and arguments. It is not modified.
        parrent (:class:`module`, optional): Module which may contain the
            expected object classes. When given, a string type is looked up
            as an attribute of it. (The spelling of the parameter name is
            kept as-is so that keyword callers keep working.)
        default_args (dict, optional): Default arguments for initializing the
            object. They are only used for keys missing from ``info``.

    Returns:
        any type: Object built from the dict.

    Raises:
        AssertionError: If ``info`` is not a dict containing "type", or
            ``default_args`` is neither a dict nor None.
        AttributeError: If ``parrent`` has no attribute named by the type.
        KeyError: If ``parrent`` is None and the type string names neither
            an entry of ``sys.modules`` nor a builtin class.
        TypeError: If the type is neither a string nor a type.
    """
    assert isinstance(info, dict) and 'type' in info
    assert isinstance(default_args, dict) or default_args is None
    args = info.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        if parrent is not None:
            obj_type = getattr(parrent, obj_type)
        elif obj_type in sys.modules:
            # Original lookup, kept first so existing behaviour is unchanged.
            obj_type = sys.modules[obj_type]
        else:
            # The documented form ``type="list"`` used to end in a KeyError
            # here; resolve builtin class names instead. Only classes are
            # accepted so builtin functions such as ``eval`` are never
            # returned from a config string.
            builtins_module = (sys.modules.get('builtins')
                               or sys.modules['__builtin__'])
            builtin_type = getattr(builtins_module, obj_type, None)
            if not isinstance(builtin_type, type):
                raise KeyError(obj_type)
            obj_type = builtin_type
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        # Explicit values in ``info`` take precedence over the defaults.
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)