[Refactor] Use `MMLogger` from MMEngine in `get_logger` and `print_log`.

pull/913/head
yingfhu 2022-05-11 08:26:37 +00:00 committed by mzr1996
parent 088d5b5add
commit de002e455f
7 changed files with 44 additions and 23 deletions

View File

@@ -4,13 +4,13 @@
import itertools
import logging
from typing import List, Optional
from typing import List, Optional, Union
import mmcv
import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner, get_dist_info
from mmcv.utils import print_log
from mmengine.logging import print_log
from mmengine.hooks import Hook
from torch.functional import Tensor
from torch.nn import GroupNorm
@@ -52,10 +52,11 @@ def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]:
@torch.no_grad()
def update_bn_stats(model: nn.Module,
loader: DataLoader,
num_samples: int = 8192,
logger: Optional[logging.Logger] = None) -> None:
def update_bn_stats(
model: nn.Module,
loader: DataLoader,
num_samples: int = 8192,
logger: Optional[Union[logging.Logger, str]] = None) -> None:
"""Computes precise BN stats on training data.
Args:
@@ -63,8 +64,15 @@ def update_bn_stats(model: nn.Module,
loader (DataLoader): PyTorch dataloader.
num_samples (int): The number of samples to update the bn stats.
Defaults to 8192.
logger (:obj:`logging.Logger` | None): Logger for logging.
Default: None.
logger (logging.Logger or str, optional): If the type of logger is
``logging.Logger``, we directly use logger to log messages.
Some special loggers are:
- "silent": No message will be printed.
- "current": Use latest created logger to log message.
- other str: Instance name of logger. The corresponding logger
will log message if it has been created, otherwise will raise a
`ValueError`.
- None: The `print()` method will be used to print log messages.
"""
# get dist info
rank, world_size = get_dist_info()
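
The widened `logger` type follows the dispatch rules of `mmengine.logging.print_log` quoted in the docstring above. A minimal sketch of the non-string and special-string cases (the instance name `'mmcls'` is only an illustration; behaviour is as described by that docstring, not an exhaustive reference):

```python
from mmengine.logging import MMLogger, print_log

logger = MMLogger.get_instance('mmcls')  # create a named MMLogger first

print_log('falls back to print()')                       # logger=None
print_log('goes to an explicit logger', logger=logger)   # logging.Logger instance
print_log('goes to the latest created logger', logger='current')
print_log('is swallowed silently', logger='silent')
```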

View File

@@ -4,7 +4,7 @@ import math
from collections import defaultdict
import numpy as np
from mmcv.utils import print_log
from mmengine.logging import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from mmcls.registry import DATASETS
@@ -57,8 +57,16 @@ class ConcatDataset(_ConcatDataset):
indices (list, optional): The indices of samples corresponding to
the results. It's unavailable on ConcatDataset.
Defaults to None.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Defaults to None.
logger (logging.Logger or str, optional): If the type of logger is
``logging.Logger``, we directly use logger to log messages.
Some special loggers are:
- "silent": No message will be printed.
- "current": Use latest created logger to log message.
- other str: Instance name of logger. The corresponding logger
will log message if it has been created, otherwise will raise a
`ValueError`.
- None: The `print()` method will be used to print log
messages.
Returns:
dict[str: float]: AP results of the total dataset or each separate
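
The remaining case, passing an instance name as a string, only works for loggers that have already been created; per the docstring above, an unknown name raises a `ValueError`. A hedged sketch (the names used here are placeholders, not names mmcls actually registers):

```python
from mmengine.logging import MMLogger, print_log

MMLogger.get_instance('eval_logger')                     # register the instance first
print_log('evaluation finished', logger='eval_logger')   # routed by instance name

try:
    print_log('never shown', logger='no_such_logger')
except ValueError:
    pass  # the name was never created, as the docstring above warns
```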

View File

@@ -316,7 +316,7 @@ class VisionTransformer(BaseBackbone):
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmcv.utils import print_log
from mmengine.logging import print_log
logger = get_root_logger()
print_log(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '

View File

@@ -3,7 +3,7 @@ import json
import logging
from collections import defaultdict
from mmcv.utils import get_logger
from mmengine.logging import MMLogger
def get_root_logger(log_file=None, log_level=logging.INFO):
@@ -17,7 +17,15 @@ def get_root_logger(log_file=None, log_level=logging.INFO):
Returns:
:obj:`logging.Logger`: The obtained logger
"""
return get_logger('mmcls', log_file, log_level)
try:
return MMLogger.get_instance(
'mmcls',
logger_name='mmcls',
log_file=log_file,
log_level=log_level)
except AssertionError:
# if root logger already existed, no extra kwargs needed.
return MMLogger.get_instance('mmcls')
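
The `try`/`except AssertionError` is needed because `MMLogger.get_instance` caches instances by name: in the MMEngine version this change targets, asking for an existing name while passing extra keyword arguments trips an assertion, so later calls fall back to fetching the cached `'mmcls'` instance. A rough sketch of the resulting behaviour (the log path is a placeholder):

```python
from mmcls.utils import get_root_logger

first = get_root_logger(log_file='/tmp/mmcls.log')  # creates and configures 'mmcls'
second = get_root_logger()                          # hits the except branch, reuses the cache

assert first is second  # both calls return the same MMLogger instance
```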
def load_json_log(json_log):

View File

@@ -5,8 +5,7 @@ import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer
from mmcv.utils import get_logger
from mmcv.utils.logging import print_log
from mmengine.logging import MMLogger, print_log
from torch.utils.data import DataLoader, Dataset
from mmcls.core.hook import PreciseBNHook
@@ -102,7 +101,7 @@ def test_precise_bn():
loader = DataLoader(test_dataset, batch_size=2)
model = ExampleModel()
optimizer = build_optimizer(model, optimizer_cfg)
logger = get_logger('precise_bn')
logger = MMLogger.get_instance('precise_bn')
runner = EpochBasedRunner(
model=model,
batch_processor=None,

View File

@@ -3,14 +3,13 @@ import os
import os.path as osp
import tempfile
import mmcv.utils.logging
from mmengine.logging import MMLogger
from mmcls.utils import get_root_logger, load_json_log
def test_get_root_logger():
# Reset the initialized log
mmcv.utils.logging.logger_initialized = {}
# set all logger instance
MMLogger._instance_dict = {}
with tempfile.TemporaryDirectory() as tmpdirname:
log_path = osp.join(tmpdirname, 'test.log')
@@ -33,7 +32,6 @@ def test_get_root_logger():
for handler in handlers:
handler.close()
logger.removeHandler(handler)
os.remove(log_path)
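
Resetting `MMLogger._instance_dict` replaces the old `logger_initialized = {}` trick for the same reason: instances are cached by name, so without clearing the cache a `'mmcls'` logger created by an earlier test would be returned unchanged and the freshly requested `log_file` would never be attached. A short sketch of the failure mode the reset avoids (file names are placeholders):

```python
from mmengine.logging import MMLogger
from mmcls.utils import get_root_logger

MMLogger._instance_dict = {}                  # clean slate, as in the test above
first = get_root_logger(log_file='a.log')     # 'mmcls' gets a file handler for a.log
second = get_root_logger(log_file='b.log')    # cached instance wins; b.log is ignored
assert first is second
```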
def test_load_json_log():

View File

@@ -2,12 +2,12 @@
import argparse
import copy
import math
import pkg_resources
import re
from pathlib import Path
import mmcv
import numpy as np
import pkg_resources
from mmcv import Config, DictAction
from mmcv.utils import to_2tuple
from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm