[Refactor] Apply get_current_instance instead of get_root_logger (#394)
parent
c8cf491c4c
commit
045b1fde8e
|
@ -4,12 +4,12 @@ from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmengine.dist import get_dist_info
|
from mmengine.dist import get_dist_info
|
||||||
|
from mmengine.logging import MMLogger
|
||||||
from mmengine.optim import DefaultOptimWrapperConstructor
|
from mmengine.optim import DefaultOptimWrapperConstructor
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from mmselfsup.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
|
from mmselfsup.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
|
||||||
OPTIMIZERS)
|
OPTIMIZERS)
|
||||||
from mmselfsup.utils import get_root_logger
|
|
||||||
|
|
||||||
|
|
||||||
def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
|
def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
|
||||||
|
@ -88,7 +88,7 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
|
||||||
submodule of DCN, `is_dcn_module` will be passed to
|
submodule of DCN, `is_dcn_module` will be passed to
|
||||||
control conv_offset layer's learning rate. Defaults to None.
|
control conv_offset layer's learning rate. Defaults to None.
|
||||||
"""
|
"""
|
||||||
logger = get_root_logger()
|
logger = MMLogger.get_current_instance()
|
||||||
|
|
||||||
model_type = optimizer_cfg.pop('model_type', None)
|
model_type = optimizer_cfg.pop('model_type', None)
|
||||||
# model_type should not be None
|
# model_type should not be None
|
||||||
|
|
|
@ -5,13 +5,11 @@ from .collect import dist_forward_collect, nondist_forward_collect
|
||||||
from .collect_env import collect_env
|
from .collect_env import collect_env
|
||||||
from .distributed_sinkhorn import distributed_sinkhorn
|
from .distributed_sinkhorn import distributed_sinkhorn
|
||||||
from .gather import concat_all_gather, gather_tensors, gather_tensors_batch
|
from .gather import concat_all_gather, gather_tensors, gather_tensors_batch
|
||||||
from .logger import get_root_logger
|
|
||||||
from .setup_env import register_all_modules
|
from .setup_env import register_all_modules
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
|
'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
|
||||||
'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
|
'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
|
||||||
'sync_random_seed', 'distributed_sinkhorn', 'concat_all_gather',
|
'sync_random_seed', 'distributed_sinkhorn', 'concat_all_gather',
|
||||||
'gather_tensors', 'gather_tensors_batch', 'get_root_logger',
|
'gather_tensors', 'gather_tensors_batch', 'register_all_modules'
|
||||||
'register_all_modules'
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,27 +0,0 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from mmengine.logging import MMLogger
|
|
||||||
|
|
||||||
|
|
||||||
def get_root_logger(log_file: str = None,
|
|
||||||
log_level: int = logging.INFO) -> logging.Logger:
|
|
||||||
"""Get root logger.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
log_file (str, optional): File path of log. Defaults to None.
|
|
||||||
log_level (int, optional): The level of logger.
|
|
||||||
Defaults to :obj:`logging.INFO`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:obj:`logging.Logger`: The obtained logger
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return MMLogger.get_instance(
|
|
||||||
'mmselfsup',
|
|
||||||
logger_name='mmselfsup',
|
|
||||||
log_file=log_file,
|
|
||||||
log_level=log_level)
|
|
||||||
except AssertionError:
|
|
||||||
# if root logger already existed, no extra kwargs needed.
|
|
||||||
return MMLogger.get_instance('mmselfsup')
|
|
|
@ -1,34 +0,0 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import os.path as osp
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from mmengine.logging import MMLogger
|
|
||||||
|
|
||||||
from mmselfsup.utils import get_root_logger
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_root_logger():
|
|
||||||
# set all logger instance
|
|
||||||
MMLogger._instance_dict = {}
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
||||||
log_path = osp.join(tmpdirname, 'test.log')
|
|
||||||
|
|
||||||
logger = get_root_logger(log_file=log_path)
|
|
||||||
message1 = 'adhsuadghj'
|
|
||||||
logger.info(message1)
|
|
||||||
|
|
||||||
logger2 = get_root_logger()
|
|
||||||
message2 = 'm,tkrgmkr'
|
|
||||||
logger2.info(message2)
|
|
||||||
|
|
||||||
with open(log_path, 'r') as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
assert message1 in lines[0]
|
|
||||||
assert message2 in lines[1]
|
|
||||||
|
|
||||||
assert logger is logger2
|
|
||||||
|
|
||||||
handlers = list(logger.handlers)
|
|
||||||
for handler in handlers:
|
|
||||||
handler.close()
|
|
||||||
logger.removeHandler(handler)
|
|
|
@ -11,6 +11,7 @@ import torch
|
||||||
from mmengine.config import Config, DictAction
|
from mmengine.config import Config, DictAction
|
||||||
from mmengine.data import pseudo_collate, worker_init_fn
|
from mmengine.data import pseudo_collate, worker_init_fn
|
||||||
from mmengine.dist import get_rank, init_dist
|
from mmengine.dist import get_rank, init_dist
|
||||||
|
from mmengine.logging import MMLogger
|
||||||
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
|
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
|
||||||
from mmengine.runner import load_checkpoint
|
from mmengine.runner import load_checkpoint
|
||||||
from mmengine.utils import mkdir_or_exist
|
from mmengine.utils import mkdir_or_exist
|
||||||
|
@ -19,7 +20,7 @@ from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from mmselfsup.models.utils import Extractor
|
from mmselfsup.models.utils import Extractor
|
||||||
from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS
|
from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS
|
||||||
from mmselfsup.utils import get_root_logger, register_all_modules
|
from mmselfsup.utils import register_all_modules
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -136,7 +137,11 @@ def main():
|
||||||
|
|
||||||
# init the logger before other steps
|
# init the logger before other steps
|
||||||
log_file = osp.join(tsne_work_dir, 'extract.log')
|
log_file = osp.join(tsne_work_dir, 'extract.log')
|
||||||
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
|
logger = MMLogger.get_instance(
|
||||||
|
'mmselfsup',
|
||||||
|
logger_name='mmselfsup',
|
||||||
|
log_file=log_file,
|
||||||
|
log_level=cfg.log_level)
|
||||||
|
|
||||||
# build the dataset
|
# build the dataset
|
||||||
dataset_cfg = Config.fromfile(args.dataset_config)
|
dataset_cfg = Config.fromfile(args.dataset_config)
|
||||||
|
|
|
@ -11,6 +11,7 @@ import torch
|
||||||
from mmengine.config import Config, DictAction
|
from mmengine.config import Config, DictAction
|
||||||
from mmengine.data import pseudo_collate, worker_init_fn
|
from mmengine.data import pseudo_collate, worker_init_fn
|
||||||
from mmengine.dist import get_rank, init_dist
|
from mmengine.dist import get_rank, init_dist
|
||||||
|
from mmengine.logging import MMLogger
|
||||||
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
|
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
|
||||||
from mmengine.runner import load_checkpoint
|
from mmengine.runner import load_checkpoint
|
||||||
from mmengine.utils import mkdir_or_exist
|
from mmengine.utils import mkdir_or_exist
|
||||||
|
@ -18,7 +19,7 @@ from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from mmselfsup.models.utils import Extractor
|
from mmselfsup.models.utils import Extractor
|
||||||
from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS
|
from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS
|
||||||
from mmselfsup.utils import get_root_logger, register_all_modules
|
from mmselfsup.utils import register_all_modules
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -100,7 +101,11 @@ def main():
|
||||||
# init the logger before other steps
|
# init the logger before other steps
|
||||||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||||
log_file = osp.join(cfg.work_dir, f'extract_{timestamp}.log')
|
log_file = osp.join(cfg.work_dir, f'extract_{timestamp}.log')
|
||||||
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
|
logger = MMLogger.get_instance(
|
||||||
|
'mmselfsup',
|
||||||
|
logger_name='mmselfsup',
|
||||||
|
log_file=log_file,
|
||||||
|
log_level=cfg.log_level)
|
||||||
|
|
||||||
# build the dataset
|
# build the dataset
|
||||||
dataset_cfg = Config.fromfile(args.dataset_config)
|
dataset_cfg = Config.fromfile(args.dataset_config)
|
||||||
|
|
Loading…
Reference in New Issue