[Refactor] Use `MMLogger` from MMEngine in `get_logger` and `print_log`.
parent
088d5b5add
commit
de002e455f
|
@ -4,13 +4,13 @@
|
|||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import EpochBasedRunner, get_dist_info
|
||||
from mmcv.utils import print_log
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.hooks import Hook
|
||||
from torch.functional import Tensor
|
||||
from torch.nn import GroupNorm
|
||||
|
@ -52,10 +52,11 @@ def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]:
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def update_bn_stats(model: nn.Module,
|
||||
loader: DataLoader,
|
||||
num_samples: int = 8192,
|
||||
logger: Optional[logging.Logger] = None) -> None:
|
||||
def update_bn_stats(
|
||||
model: nn.Module,
|
||||
loader: DataLoader,
|
||||
num_samples: int = 8192,
|
||||
logger: Optional[Union[logging.Logger, str]] = None) -> None:
|
||||
"""Computes precise BN stats on training data.
|
||||
|
||||
Args:
|
||||
|
@ -63,8 +64,15 @@ def update_bn_stats(model: nn.Module,
|
|||
loader (DataLoader): PyTorch dataloader._dataloader
|
||||
num_samples (int): The number of samples to update the bn stats.
|
||||
Defaults to 8192.
|
||||
logger (:obj:`logging.Logger` | None): Logger for logging.
|
||||
Default: None.
|
||||
logger (logging.Logger or str, optional): If the type of logger is
|
||||
``logging.Logger``, we directly use logger to log messages.
|
||||
Some special loggers are:
|
||||
- "silent": No message will be printed.
|
||||
- "current": Use latest created logger to log message.
|
||||
- other str: Instance name of logger. The corresponding logger
|
||||
will log message if it has been created, otherwise will raise a
|
||||
`ValueError`.
|
||||
- None: The `print()` method will be used to print log messages.
|
||||
"""
|
||||
# get dist info
|
||||
rank, world_size = get_dist_info()
|
||||
|
|
|
@ -4,7 +4,7 @@ import math
|
|||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from mmcv.utils import print_log
|
||||
from mmengine.logging import print_log
|
||||
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
|
@ -57,8 +57,16 @@ class ConcatDataset(_ConcatDataset):
|
|||
indices (list, optional): The indices of samples corresponding to
|
||||
the results. It's unavailable on ConcatDataset.
|
||||
Defaults to None.
|
||||
logger (logging.Logger | str, optional): Logger used for printing
|
||||
related information during evaluation. Defaults to None.
|
||||
logger (logging.Logger or str, optional): If the type of logger is
|
||||
``logging.Logger``, we directly use logger to log messages.
|
||||
Some special loggers are:
|
||||
- "silent": No message will be printed.
|
||||
- "current": Use latest created logger to log message.
|
||||
- other str: Instance name of logger. The corresponding logger
|
||||
will log message if it has been created, otherwise will raise a
|
||||
`ValueError`.
|
||||
- None: The `print()` method will be used to print log
|
||||
messages.
|
||||
|
||||
Returns:
|
||||
dict[str: float]: AP results of the total dataset or each separate
|
||||
|
|
|
@ -316,7 +316,7 @@ class VisionTransformer(BaseBackbone):
|
|||
|
||||
ckpt_pos_embed_shape = state_dict[name].shape
|
||||
if self.pos_embed.shape != ckpt_pos_embed_shape:
|
||||
from mmcv.utils import print_log
|
||||
from mmengine.logging import print_log
|
||||
logger = get_root_logger()
|
||||
print_log(
|
||||
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
|
||||
|
|
|
@ -3,7 +3,7 @@ import json
|
|||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
from mmcv.utils import get_logger
|
||||
from mmengine.logging import MMLogger
|
||||
|
||||
|
||||
def get_root_logger(log_file=None, log_level=logging.INFO):
|
||||
|
@ -17,7 +17,15 @@ def get_root_logger(log_file=None, log_level=logging.INFO):
|
|||
Returns:
|
||||
:obj:`logging.Logger`: The obtained logger
|
||||
"""
|
||||
return get_logger('mmcls', log_file, log_level)
|
||||
try:
|
||||
return MMLogger.get_instance(
|
||||
'mmcls',
|
||||
logger_name='mmcls',
|
||||
log_file=log_file,
|
||||
log_level=log_level)
|
||||
except AssertionError:
|
||||
# if root logger already existed, no extra kwargs needed.
|
||||
return MMLogger.get_instance('mmcls')
|
||||
|
||||
|
||||
def load_json_log(json_log):
|
||||
|
|
|
@ -5,8 +5,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||
from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer
|
||||
from mmcv.utils import get_logger
|
||||
from mmcv.utils.logging import print_log
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmcls.core.hook import PreciseBNHook
|
||||
|
@ -102,7 +101,7 @@ def test_precise_bn():
|
|||
loader = DataLoader(test_dataset, batch_size=2)
|
||||
model = ExampleModel()
|
||||
optimizer = build_optimizer(model, optimizer_cfg)
|
||||
logger = get_logger('precise_bn')
|
||||
logger = MMLogger.get_instance('precise_bn')
|
||||
runner = EpochBasedRunner(
|
||||
model=model,
|
||||
batch_processor=None,
|
||||
|
|
|
@ -3,14 +3,13 @@ import os
|
|||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
import mmcv.utils.logging
|
||||
|
||||
from mmengine.logging import MMLogger
|
||||
from mmcls.utils import get_root_logger, load_json_log
|
||||
|
||||
|
||||
def test_get_root_logger():
|
||||
# Reset the initialized log
|
||||
mmcv.utils.logging.logger_initialized = {}
|
||||
# set all logger instance
|
||||
MMLogger._instance_dict = {}
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
log_path = osp.join(tmpdirname, 'test.log')
|
||||
|
||||
|
@ -33,7 +32,6 @@ def test_get_root_logger():
|
|||
for handler in handlers:
|
||||
handler.close()
|
||||
logger.removeHandler(handler)
|
||||
os.remove(log_path)
|
||||
|
||||
|
||||
def test_load_json_log():
|
||||
|
|
|
@ -2,12 +2,12 @@
|
|||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import pkg_resources
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pkg_resources
|
||||
from mmcv import Config, DictAction
|
||||
from mmcv.utils import to_2tuple
|
||||
from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm
|
||||
|
|
Loading…
Reference in New Issue