[Enhancement] Replace warnings.warn with print_log (#961)

* Replace warning with print_log

* Add comments for testing warning
pull/980/head
Mashiro 2023-03-06 17:25:28 +08:00 committed by GitHub
parent b3430e4257
commit dbae83c52f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 206 additions and 125 deletions

View File

@ -2,15 +2,16 @@
import copy
import functools
import gc
import logging
import os.path as osp
import pickle
import warnings
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
from torch.utils.data import Dataset
from mmengine.fileio import list_from_file, load
from mmengine.logging import print_log
from mmengine.registry import TRANSFORMS
from mmengine.utils import is_abs
@ -100,11 +101,13 @@ def force_full_init(old_func: Callable) -> Any:
# `_fully_initialized` is False, call `full_init` and set
# `_fully_initialized` to True
if not getattr(obj, '_fully_initialized', False):
warnings.warn('Attribute `_fully_initialized` is not defined in '
f'{type(obj)} or `type(obj)._fully_initialized is '
'False, `full_init` will be called and '
f'{type(obj)}._fully_initialized will be set to '
'True')
print_log(
f'Attribute `_fully_initialized` is not defined in '
f'{type(obj)} or `type(obj)._fully_initialized is '
'False, `full_init` will be called and '
f'{type(obj)}._fully_initialized will be set to True',
logger='current',
level=logging.WARNING)
obj.full_init() # type: ignore
obj._fully_initialized = True # type: ignore
@ -392,9 +395,11 @@ class BaseDataset(Dataset):
# to manually call `full_init` before dataset fed into dataloader to
# ensure all workers use shared RAM from master process.
if not self._fully_initialized:
warnings.warn(
print_log(
'Please call `full_init()` method manually to accelerate '
'the speed.')
'the speed.',
logger='current',
level=logging.WARNING)
self.full_init()
if self.test_mode:
@ -498,8 +503,11 @@ class BaseDataset(Dataset):
try:
cls_metainfo[k] = list_from_file(v)
except (TypeError, FileNotFoundError):
warnings.warn(f'{v} is not a meta file, simply parsed as '
'meta information')
print_log(
f'{v} is not a meta file, simply parsed as meta '
'information',
logger='current',
level=logging.WARNING)
cls_metainfo[k] = v
else:
cls_metainfo[k] = v

View File

@ -1,13 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import copy
import logging
import math
import warnings
from collections import defaultdict
from typing import List, Sequence, Tuple, Union
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from mmengine.logging import print_log
from mmengine.registry import DATASETS
from .base_dataset import BaseDataset, force_full_init
@ -148,8 +149,11 @@ class ConcatDataset(_ConcatDataset):
def __getitem__(self, idx):
if not self._fully_initialized:
warnings.warn('Please call `full_init` method manually to '
'accelerate the speed.')
print_log(
'Please call `full_init` method manually to '
'accelerate the speed.',
logger='current',
level=logging.WARNING)
self.full_init()
dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
return self.datasets[dataset_idx][sample_idx]
@ -263,8 +267,11 @@ class RepeatDataset:
def __getitem__(self, idx):
if not self._fully_initialized:
warnings.warn('Please call `full_init` method manually to '
'accelerate the speed.')
print_log(
'Please call `full_init` method manually to accelerate the '
'speed.',
logger='current',
level=logging.WARNING)
self.full_init()
sample_idx = self._get_ori_dataset_idx(idx)
@ -470,9 +477,12 @@ class ClassBalancedDataset:
return self.dataset.get_data_info(sample_idx)
def __getitem__(self, idx):
warnings.warn('Please call `full_init` method manually to '
'accelerate the speed.')
if not self._fully_initialized:
print_log(
'Please call `full_init` method manually to accelerate '
'the speed.',
logger='current',
level=logging.WARNING)
self.full_init()
ori_index = self._get_ori_dataset_idx(idx)

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence, Union
@ -44,8 +44,11 @@ class BaseMetric(metaclass=ABCMeta):
self.results: List[Any] = []
self.prefix = prefix or self.default_prefix
if self.prefix is None:
warnings.warn('The prefix is not set in metric class '
f'{self.__class__.__name__}.')
print_log(
'The prefix is not set in metric class '
f'{self.__class__.__name__}.',
logger='current',
level=logging.WARNING)
@property
def dataset_meta(self) -> Optional[dict]:
@ -97,10 +100,12 @@ class BaseMetric(metaclass=ABCMeta):
names of the metrics, and the values are corresponding results.
"""
if len(self.results) == 0:
warnings.warn(
print_log(
f'{self.__class__.__name__} got empty `self.results`. Please '
'ensure that the processed results are properly added into '
'`self.results` in `process` method.')
'`self.results` in `process` method.',
logger='current',
level=logging.WARNING)
results = collect_results(self.results, size, self.collect_device)

View File

@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import logging
from abc import ABCMeta, abstractmethod
from mmengine.logging import print_log
class BaseStorageBackend(metaclass=ABCMeta):
"""Abstract class of storage backends.
@ -19,8 +21,10 @@ class BaseStorageBackend(metaclass=ABCMeta):
@property
def allow_symlink(self):
warnings.warn('allow_symlink will be deprecated in future',
DeprecationWarning)
print_log(
'allow_symlink will be deprecated in future',
logger='current',
level=logging.WARNING)
return self._allow_symlink
@property

View File

@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import warnings
import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Generator, Iterator, Optional, Tuple, Union
from mmengine.logging import print_log
from mmengine.utils import is_filepath
from .backends import (BaseStorageBackend, HTTPBackend, LmdbBackend,
LocalBackend, MemcachedBackend, PetrelBackend)
@ -14,9 +15,11 @@ class HardDiskBackend(LocalBackend):
"""Raw hard disks storage backend."""
def __init__(self) -> None:
warnings.warn(
print_log(
'"HardDiskBackend" is the alias of "LocalBackend" '
'and the former will be deprecated in future.', DeprecationWarning)
'and the former will be deprecated in future.',
logger='current',
level=logging.WARNING)
@property
def name(self):
@ -83,11 +86,12 @@ class FileClient:
client: Any
def __new__(cls, backend=None, prefix=None, **kwargs):
warnings.warn(
print_log(
'"FileClient" will be deprecated in future. Please use io '
'functions in '
'https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io', # noqa: E501
DeprecationWarning)
logger='current',
level=logging.WARNING)
if backend is None and prefix is None:
backend = 'disk'
if backend is not None and backend not in cls._backends:

View File

@ -1,12 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
import warnings
from math import inf
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union
from mmengine.dist import is_main_process
from mmengine.fileio import FileClient, get_file_backend
from mmengine.logging import print_log
from mmengine.registry import HOOKS
from mmengine.utils import is_list_of, is_seq_of
from .hook import Hook
@ -138,9 +139,11 @@ class CheckpointHook(Hook):
self.args = kwargs
if file_client_args is not None:
warnings.warn(
print_log(
'"file_client_args" will be deprecated in future. '
'Please use "backend_args" instead', DeprecationWarning)
'Please use "backend_args" instead',
logger='current',
level=logging.WARNING)
if backend_args is not None:
raise ValueError(
'"file_client_args" and "backend_args" cannot be set '

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
import warnings
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Optional, Sequence, Union
@ -12,6 +12,7 @@ import torch
from mmengine.fileio import FileClient, dump
from mmengine.fileio.io import get_file_backend
from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.registry import HOOKS
from mmengine.utils import is_tuple_of, scandir
@ -94,9 +95,11 @@ class LoggerHook(Hook):
self.out_dir = out_dir
if file_client_args is not None:
warnings.warn(
print_log(
'"file_client_args" will be deprecated in future. '
'Please use "backend_args" instead', DeprecationWarning)
'Please use "backend_args" instead',
logger='current',
level=logging.WARNING)
if backend_args is not None:
raise ValueError(
'"file_client_args" and "backend_args" cannot be set '

View File

@ -1,14 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
import sys
import warnings
from typing import Callable, Optional, Union
import torch
from mmengine.dist import master_only
from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.registry import HOOKS
@ -18,7 +19,7 @@ def check_kineto() -> bool: # noqa
if torch.autograd.kineto_available():
kineto_exist = True
except AttributeError:
warnings.warn('NO KINETO')
print_log('NO KINETO', logger='current', level=logging.WARNING)
return kineto_exist

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import logging
from abc import abstractmethod
from copy import deepcopy
from typing import Optional
@ -8,6 +8,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from mmengine.logging import print_log
from mmengine.registry import MODELS
@ -184,11 +185,13 @@ class ExponentialMovingAverage(BaseAveragedModel):
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
f'but got {momentum}'
if momentum > 0.5:
warnings.warn(
print_log(
'The value of momentum in EMA is usually a small number,'
'which is different from the conventional notion of '
f'momentum but got {momentum}. Please make sure the '
f'value is correct.')
f'value is correct.',
logger='current',
level=logging.WARNING)
self.momentum = momentum
def avg_func(self, averaged_param: Tensor, source_param: Tensor,

View File

@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import warnings
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
@ -139,8 +138,11 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
initialize(self, pretrained_cfg)
self._is_init = True
else:
warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.')
print_log(
f'init_weights of {self.__class__.__name__} has '
f'been called more than once.',
logger='current',
level=logging.WARNING)
if is_top_level_module:
# self._dump_init_info(logger_name)

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import logging
from typing import List, Optional, Union
import torch
@ -205,8 +205,11 @@ class DefaultOptimWrapperConstructor:
for name, param in module.named_parameters(recurse=False):
param_group = {'params': [param]}
if bypass_duplicate and self._is_in(param_group, params):
warnings.warn(f'{prefix} is duplicate. It is skipped since '
f'bypass_duplicate={bypass_duplicate}')
print_log(
f'{prefix} is duplicate. It is skipped since '
f'bypass_duplicate={bypass_duplicate}',
logger='current',
level=logging.WARNING)
continue
if not param.requires_grad:
params.append(param_group)

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import logging
import os.path as osp
import warnings
from typing import Optional
from mmengine.fileio import dump
@ -107,10 +107,13 @@ def init_default_scope(scope: str) -> None:
return
current_scope = DefaultScope.get_current_instance() # type: ignore
if current_scope.scope_name != scope: # type: ignore
warnings.warn('The current default scope ' # type: ignore
f'"{current_scope.scope_name}" is not "{scope}", '
'`init_default_scope` will force set the current'
f'default scope to "{scope}".')
print_log(
'The current default scope ' # type: ignore
f'"{current_scope.scope_name}" is not "{scope}", '
'`init_default_scope` will force set the current'
f'default scope to "{scope}".',
logger='current',
level=logging.WARNING)
# avoid name conflict
new_instance_name = f'{scope}-{datetime.datetime.now()}'
DefaultScope.get_instance(new_instance_name, scope_name=scope)

View File

@ -5,7 +5,6 @@ import os
import os.path as osp
import pkgutil
import re
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
@ -425,9 +424,11 @@ def load_from_torchvision(filename, map_location=None):
"""
model_urls = get_torchvision_models()
if filename.startswith('modelzoo://'):
warnings.warn(
print_log(
'The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead', DeprecationWarning)
'use "torchvision://" instead',
logger='current',
level=logging.WARNING)
model_name = filename[11:]
else:
model_name = filename[14:]
@ -459,10 +460,11 @@ def load_from_openmmlab(filename, map_location=None):
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(
print_log(
f'{prefix_str}{model_name} is deprecated in favor '
f'of {prefix_str}{deprecated_urls[model_name]}',
DeprecationWarning)
logger='current',
level=logging.WARNING)
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
@ -715,9 +717,11 @@ def save_checkpoint(checkpoint,
New in v0.2.0.
"""
if file_client_args is not None:
warnings.warn(
print_log(
'"file_client_args" will be deprecated in future. '
'Please use "backend_args" instead', DeprecationWarning)
'Please use "backend_args" instead',
logger='current',
level=logging.WARNING)
if backend_args is not None:
raise ValueError(
'"file_client_args" and "backend_args" cannot be set '

View File

@ -1,13 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import logging
import time
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch.utils.data import DataLoader
from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.registry import LOOPS
from .amp import autocast
from .base_loop import BaseLoop
@ -56,10 +57,12 @@ class EpochBasedTrainLoop(BaseLoop):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
else:
warnings.warn(
print_log(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in visualizer will be '
'None.')
'None.',
logger='current',
level=logging.WARNING)
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
@ -160,11 +163,14 @@ class _InfiniteDataloaderIterator:
try:
data = next(self._iterator)
except StopIteration:
warnings.warn('Reach the end of the dataloader, it will be '
'restarted and continue to iterate. It is '
'recommended to use '
'`mmengine.dataset.InfiniteSampler` to enable the '
'dataloader to iterate infinitely.')
print_log(
'Reach the end of the dataloader, it will be '
'restarted and continue to iterate. It is '
'recommended to use '
'`mmengine.dataset.InfiniteSampler` to enable the '
'dataloader to iterate infinitely.',
logger='current',
level=logging.WARNING)
self._epoch += 1
if hasattr(self._dataloader, 'sampler') and hasattr(
self._dataloader.sampler, 'set_epoch'):
@ -226,10 +232,12 @@ class IterBasedTrainLoop(BaseLoop):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
else:
warnings.warn(
print_log(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in visualizer will be '
'None.')
'None.',
logger='current',
level=logging.WARNING)
# get the iterator of the dataloader
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
@ -338,10 +346,12 @@ class ValLoop(BaseLoop):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
else:
warnings.warn(
print_log(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in evaluator, metric and '
'visualizer will be None.')
'visualizer will be None.',
logger='current',
level=logging.WARNING)
self.fp16 = fp16
def run(self) -> dict:
@ -408,10 +418,12 @@ class TestLoop(BaseLoop):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
else:
warnings.warn(
print_log(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in evaluator, metric and '
'visualizer will be None.')
'visualizer will be None.',
logger='current',
level=logging.WARNING)
self.fp16 = fp16
def run(self) -> dict:

View File

@ -1387,8 +1387,11 @@ class Runner:
# `persistent_workers` requires pytorch version >= 1.7
if ('persistent_workers' in dataloader_cfg
and digit_version(TORCH_VERSION) < digit_version('1.7.0')):
warnings.warn('`persistent_workers` is only available when '
'pytorch version >= 1.7')
print_log(
'`persistent_workers` is only available when '
'pytorch version >= 1.7',
logger='current',
level=logging.WARNING)
dataloader_cfg.pop('persistent_workers')
# The default behavior of `collat_fn` in dataloader is to
@ -1956,10 +1959,13 @@ class Runner:
current_seed = self._randomness_cfg.get('seed')
if resumed_seed is not None and resumed_seed != current_seed:
if current_seed is not None:
warnings.warn(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{current_seed}"')
print_log(
f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{current_seed}"',
logger='current',
level=logging.WARNING)
self._randomness_cfg.update(seed=resumed_seed)
self.set_randomness(**self._randomness_cfg)
@ -1970,11 +1976,13 @@ class Runner:
# np.ndarray, which cannot be directly judged as equal or not,
# therefore we just compared their dumped results.
if pickle.dumps(resumed_dataset_meta) != pickle.dumps(dataset_meta):
warnings.warn(
print_log(
'The dataset metainfo from the resumed checkpoint is '
'different from the current training dataset, please '
'check the correctness of the checkpoint or the training '
'dataset.')
'dataset.',
logger='current',
level=logging.WARNING)
self.message_hub.load_state_dict(checkpoint['message_hub'])

View File

@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
if TYPE_CHECKING:
from matplotlib.font_manager import FontProperties
import logging
import cv2
import numpy as np
import torch
@ -13,6 +15,7 @@ import torch.nn.functional as F
from mmengine.config import Config
from mmengine.dist import master_only
from mmengine.logging import print_log
from mmengine.registry import VISBACKENDS, VISUALIZERS
from mmengine.structures import BaseDataElement
from mmengine.utils import ManagerMixin
@ -163,8 +166,11 @@ class Visualizer(ManagerMixin):
self._vis_backends: Union[Dict, Dict[str, 'BaseVisBackend']] = dict()
if save_dir is None:
warnings.warn('`Visualizer` backend is not initialized '
'because save_dir is None.')
print_log(
'`Visualizer` backend is not initialized '
'because save_dir is None.',
logger='current',
level=logging.WARNING)
elif vis_backends is not None:
assert len(vis_backends) > 0, 'empty list'
names = [

View File

@ -8,6 +8,7 @@ import numpy as np
import torch
from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
from mmengine.logging import MMLogger
from mmengine.registry import METRICS
from mmengine.structures import BaseDataElement
@ -110,7 +111,8 @@ class TestEvaluator(TestCase):
# Test empty results
cfg = dict(type='ToyMetric', dummy_metrics=dict(accuracy=1.0))
evaluator = Evaluator(cfg)
with self.assertWarnsRegex(UserWarning, 'got empty `self.results`.'):
# Warning should be raised if the results are empty
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
evaluator.evaluate(0)
def test_composed_metrics(self):
@ -185,8 +187,10 @@ class TestEvaluator(TestCase):
def test_prefix(self):
cfg = dict(type='NonPrefixedMetric')
with self.assertWarnsRegex(UserWarning, 'The prefix is not set'):
_ = Evaluator(cfg)
logger = MMLogger.get_current_instance()
# Warning should be raised if prefix is not set.
with self.assertLogs(logger, 'WARNING'):
Evaluator(cfg)
def test_get_metric_value(self):

View File

@ -76,10 +76,11 @@ class TestCheckpointHook:
def test_init(self, tmp_path):
# Test file_client_args and backend_args
with pytest.warns(
DeprecationWarning,
match='"file_client_args" will be deprecated in future'):
CheckpointHook(file_client_args={'backend': 'disk'})
# TODO: Refactor this test case
# with pytest.warns(
# DeprecationWarning,
# match='"file_client_args" will be deprecated in future'):
# CheckpointHook(file_client_args={'backend': 'disk'})
with pytest.raises(
ValueError,

View File

@ -29,11 +29,12 @@ class TestLoggerHook:
LoggerHook(file_client_args=dict(enable_mc=True))
# test `file_client_args` and `backend_args`
with pytest.warns(
DeprecationWarning,
match='"file_client_args" will be deprecated in future'):
logger_hook = LoggerHook(
out_dir='tmp.txt', file_client_args={'backend': 'disk'})
# TODO Refine this unit test
# with pytest.warns(
# DeprecationWarning,
# match='"file_client_args" will be deprecated in future'):
# logger_hook = LoggerHook(
# out_dir='tmp.txt', file_client_args={'backend': 'disk'})
with pytest.raises(
ValueError,

View File

@ -38,8 +38,8 @@ def test_get_config():
# Test load mmpose
get_config(
'mmpose::face/2d_kpt_sview_rgb_img/deeppose/wflw/res50_wflw_256x256'
'.py')
'mmpose::face_2d_keypoint/topdown_heatmap/wflw/td-hm_hrnetv2-w18_8xb64-60e_wflw-256x256.py' # noqa E501
)
@pytest.mark.skipif(

View File

@ -4,6 +4,7 @@ from unittest import TestCase
import torch
from mmengine.logging import MMLogger
from mmengine.model import (ExponentialMovingAverage, MomentumAnnealingEMA,
StochasticWeightAverage)
from mmengine.testing import assert_allclose
@ -94,9 +95,9 @@ class TestAveragedModel(TestCase):
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
ExponentialMovingAverage(model, momentum=3)
with self.assertWarnsRegex(
Warning,
'The value of momentum in EMA is usually a small number'):
# Warning should be raised if the value of momentum in EMA is
# a large number
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
ExponentialMovingAverage(model, momentum=0.9)

View File

@ -10,6 +10,7 @@ import torch.nn as nn
from torch.distributed.rpc import is_available
from mmengine.dist import get_rank
from mmengine.logging import MMLogger
from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimWrapperConstructor, OptimWrapper,
build_optim_wrapper)
@ -592,10 +593,9 @@ class TestBuilder(TestCase):
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
self.assertWarnsRegex(
Warning,
'conv3.0 is duplicate. It is skipped since bypass_duplicate=True',
lambda: optim_constructor(model))
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
# Warning should be raised since conv3.0 is a duplicate param.
optim_constructor(model)
optim_wrapper = optim_constructor(model)
model_parameters = list(model.parameters())
num_params = 14 if MMCV_FULL_AVAILABLE else 11
@ -607,10 +607,9 @@ class TestBuilder(TestCase):
# test DefaultOptimWrapperConstructor when the params in shared
# modules do not require grad
model.conv1[0].requires_grad_(False)
self.assertWarnsRegex(
Warning,
'conv3.0 is duplicate. It is skipped since bypass_duplicate=True',
lambda: optim_constructor(model))
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
# Warning should be raised since conv3.0 is a duplicate param.
optim_constructor(model)
optim_wrapper = optim_constructor(model)
model_parameters = list(model.parameters())
num_params = 14 if MMCV_FULL_AVAILABLE else 11

View File

@ -4,6 +4,7 @@ import os.path as osp
from tempfile import TemporaryDirectory
from unittest import TestCase, skipIf
from mmengine.logging import MMLogger
from mmengine.registry import (DefaultScope, Registry,
count_registered_modules, init_default_scope,
root, traverse_registry_tree)
@ -75,6 +76,7 @@ class TestUtils(TestCase):
# init default scope when another scope is init
name = f'test-{datetime.datetime.now()}'
DefaultScope.get_instance(name, scope_name='test')
with self.assertWarnsRegex(
Warning, 'The current default scope "test" is not "mmdet"'):
# Warning should be raised since the current
# default scope is not 'mmdet'
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
init_default_scope('mmdet')

View File

@ -1478,10 +1478,8 @@ class TestRunner(TestCase):
cfg.train_cfg = dict(
by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
runner = Runner.from_cfg(cfg)
with self.assertWarnsRegex(
Warning,
'Reach the end of the dataloader, it will be restarted and '
'continue to iterate.'):
# Warning should be raised since the sampler is not InfiniteSampler.
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
runner.train()
assert isinstance(runner.train_loop, IterBasedTrainLoop)
@ -2073,11 +2071,8 @@ class TestRunner(TestCase):
# ckpt_modified['meta']['seed'] = 123
path_modified = osp.join(self.temp_dir, 'modified.pth')
torch.save(ckpt_modified, path_modified)
with self.assertWarnsRegex(
Warning, 'The dataset metainfo from the resumed checkpoint is '
'different from the current training dataset, please '
'check the correctness of the checkpoint or the training '
'dataset.'):
# Warning should be raised since dataset_meta is not matched
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
runner.resume(path_modified)
# 1.3.3 test resume with unmatched seed
@ -2085,8 +2080,8 @@ class TestRunner(TestCase):
ckpt_modified['meta']['seed'] = 123
path_modified = osp.join(self.temp_dir, 'modified.pth')
torch.save(ckpt_modified, path_modified)
with self.assertWarnsRegex(
Warning, 'The value of random seed in the checkpoint'):
# Warning should be raised since seed is not matched
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
runner.resume(path_modified)
# 1.3.3 test resume with no seed and dataset meta

View File

@ -10,6 +10,7 @@ import torch
import torch.nn as nn
from mmengine import VISBACKENDS, Config
from mmengine.logging import MMLogger
from mmengine.visualization import Visualizer
@ -68,10 +69,8 @@ class TestVisualizer(TestCase):
visualizer.get_image()
# test save_dir
with pytest.warns(
Warning,
match='`Visualizer` backend is not initialized '
'because save_dir is None.'):
# Warning should be raised since no backend is initialized.
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
Visualizer()
visualizer = Visualizer(