diff --git a/mmcls/core/hook/precise_bn_hook.py b/mmcls/core/hook/precise_bn_hook.py index f8f85fc9..b8b33f0b 100644 --- a/mmcls/core/hook/precise_bn_hook.py +++ b/mmcls/core/hook/precise_bn_hook.py @@ -4,13 +4,13 @@ import itertools import logging -from typing import List, Optional +from typing import List, Optional, Union import mmcv import torch import torch.nn as nn from mmcv.runner import EpochBasedRunner, get_dist_info -from mmcv.utils import print_log +from mmengine.logging import print_log from mmengine.hooks import Hook from torch.functional import Tensor from torch.nn import GroupNorm @@ -52,10 +52,11 @@ def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]: @torch.no_grad() -def update_bn_stats(model: nn.Module, - loader: DataLoader, - num_samples: int = 8192, - logger: Optional[logging.Logger] = None) -> None: +def update_bn_stats( + model: nn.Module, + loader: DataLoader, + num_samples: int = 8192, + logger: Optional[Union[logging.Logger, str]] = None) -> None: """Computes precise BN stats on training data. Args: @@ -63,8 +64,15 @@ def update_bn_stats(model: nn.Module, loader (DataLoader): PyTorch dataloader._dataloader num_samples (int): The number of samples to update the bn stats. Defaults to 8192. - logger (:obj:`logging.Logger` | None): Logger for logging. - Default: None. + logger (logging.Logger or str, optional): If the type of logger is + ``logging.Logger``, we directly use logger to log messages. + Some special loggers are: + - "silent": No message will be printed. + - "current": Use latest created logger to log message. + - other str: Instance name of logger. The corresponding logger + will log message if it has been created, otherwise will raise a + `ValueError`. + - None: The `print()` method will be used to print log messages. """ # get dist info rank, world_size = get_dist_info() diff --git a/mmcls/datasets/dataset_wrappers.py b/mmcls/datasets/dataset_wrappers.py index 8583db1a..49020322 100644 --- a/mmcls/datasets/dataset_wrappers.py +++ b/mmcls/datasets/dataset_wrappers.py @@ -4,7 +4,7 @@ import math from collections import defaultdict import numpy as np -from mmcv.utils import print_log +from mmengine.logging import print_log from torch.utils.data.dataset import ConcatDataset as _ConcatDataset from mmcls.registry import DATASETS @@ -57,8 +57,16 @@ class ConcatDataset(_ConcatDataset): indices (list, optional): The indices of samples corresponding to the results. It's unavailable on ConcatDataset. Defaults to None. - logger (logging.Logger | str, optional): Logger used for printing - related information during evaluation. Defaults to None. + logger (logging.Logger or str, optional): If the type of logger is + ``logging.Logger``, we directly use logger to log messages. + Some special loggers are: + - "silent": No message will be printed. + - "current": Use latest created logger to log message. + - other str: Instance name of logger. The corresponding logger + will log message if it has been created, otherwise will raise a + `ValueError`. + - None: The `print()` method will be used to print log + messages. Returns: dict[str: float]: AP results of the total dataset or each separate diff --git a/mmcls/models/backbones/vision_transformer.py b/mmcls/models/backbones/vision_transformer.py index 5ff91f06..1e20cc06 100644 --- a/mmcls/models/backbones/vision_transformer.py +++ b/mmcls/models/backbones/vision_transformer.py @@ -316,7 +316,7 @@ class VisionTransformer(BaseBackbone): ckpt_pos_embed_shape = state_dict[name].shape if self.pos_embed.shape != ckpt_pos_embed_shape: - from mmcv.utils import print_log + from mmengine.logging import print_log logger = get_root_logger() print_log( f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' diff --git a/mmcls/utils/logger.py b/mmcls/utils/logger.py index 2d77fcb9..41ca8b85 100644 --- a/mmcls/utils/logger.py +++ b/mmcls/utils/logger.py @@ -3,7 +3,7 @@ import json import logging from collections import defaultdict -from mmcv.utils import get_logger +from mmengine.logging import MMLogger def get_root_logger(log_file=None, log_level=logging.INFO): @@ -17,7 +17,15 @@ def get_root_logger(log_file=None, log_level=logging.INFO): Returns: :obj:`logging.Logger`: The obtained logger """ - return get_logger('mmcls', log_file, log_level) + try: + return MMLogger.get_instance( + 'mmcls', + logger_name='mmcls', + log_file=log_file, + log_level=log_level) + except AssertionError: + # if root logger already existed, no extra kwargs needed. + return MMLogger.get_instance('mmcls') def load_json_log(json_log): diff --git a/tests/test_runtime/test_preciseBN_hook.py b/tests/test_runtime/test_preciseBN_hook.py index d9cd7156..e52f2cc9 100644 --- a/tests/test_runtime/test_preciseBN_hook.py +++ b/tests/test_runtime/test_preciseBN_hook.py @@ -5,8 +5,7 @@ import torch import torch.nn as nn from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer -from mmcv.utils import get_logger -from mmcv.utils.logging import print_log +from mmengine.logging import MMLogger, print_log from torch.utils.data import DataLoader, Dataset from mmcls.core.hook import PreciseBNHook @@ -102,7 +101,7 @@ def test_precise_bn(): loader = DataLoader(test_dataset, batch_size=2) model = ExampleModel() optimizer = build_optimizer(model, optimizer_cfg) - logger = get_logger('precise_bn') + logger = MMLogger.get_instance('precise_bn') runner = EpochBasedRunner( model=model, batch_processor=None, diff --git a/tests/test_utils/test_logger.py b/tests/test_utils/test_logger.py index 97a6fb00..69c9ba75 100644 --- a/tests/test_utils/test_logger.py +++ b/tests/test_utils/test_logger.py @@ -3,14 +3,13 @@ import os import os.path as osp import tempfile -import mmcv.utils.logging - +from mmengine.logging import MMLogger from mmcls.utils import get_root_logger, load_json_log def test_get_root_logger(): - # Reset the initialized log - mmcv.utils.logging.logger_initialized = {} + # set all logger instance + MMLogger._instance_dict = {} with tempfile.TemporaryDirectory() as tmpdirname: log_path = osp.join(tmpdirname, 'test.log') @@ -33,7 +32,6 @@ def test_get_root_logger(): for handler in handlers: handler.close() logger.removeHandler(handler) - os.remove(log_path) def test_load_json_log(): diff --git a/tools/visualizations/vis_cam.py b/tools/visualizations/vis_cam.py index a1fcadac..3e05e5a2 100644 --- a/tools/visualizations/vis_cam.py +++ b/tools/visualizations/vis_cam.py @@ -2,12 +2,12 @@ import argparse import copy import math -import pkg_resources import re from pathlib import Path import mmcv import numpy as np +import pkg_resources from mmcv import Config, DictAction from mmcv.utils import to_2tuple from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm