[Enhancement] Replace warnings.warn with print_log (#961)
* Replace warning with print_log * Add comments for testing warningpull/980/head
parent
b3430e4257
commit
dbae83c52f
|
@ -2,15 +2,16 @@
|
|||
import copy
|
||||
import functools
|
||||
import gc
|
||||
import logging
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import warnings
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmengine.fileio import list_from_file, load
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import TRANSFORMS
|
||||
from mmengine.utils import is_abs
|
||||
|
||||
|
@ -100,11 +101,13 @@ def force_full_init(old_func: Callable) -> Any:
|
|||
# `_fully_initialized` is False, call `full_init` and set
|
||||
# `_fully_initialized` to True
|
||||
if not getattr(obj, '_fully_initialized', False):
|
||||
warnings.warn('Attribute `_fully_initialized` is not defined in '
|
||||
f'{type(obj)} or `type(obj)._fully_initialized is '
|
||||
'False, `full_init` will be called and '
|
||||
f'{type(obj)}._fully_initialized will be set to '
|
||||
'True')
|
||||
print_log(
|
||||
f'Attribute `_fully_initialized` is not defined in '
|
||||
f'{type(obj)} or `type(obj)._fully_initialized is '
|
||||
'False, `full_init` will be called and '
|
||||
f'{type(obj)}._fully_initialized will be set to True',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
obj.full_init() # type: ignore
|
||||
obj._fully_initialized = True # type: ignore
|
||||
|
||||
|
@ -392,9 +395,11 @@ class BaseDataset(Dataset):
|
|||
# to manually call `full_init` before dataset fed into dataloader to
|
||||
# ensure all workers use shared RAM from master process.
|
||||
if not self._fully_initialized:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
'Please call `full_init()` method manually to accelerate '
|
||||
'the speed.')
|
||||
'the speed.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.full_init()
|
||||
|
||||
if self.test_mode:
|
||||
|
@ -498,8 +503,11 @@ class BaseDataset(Dataset):
|
|||
try:
|
||||
cls_metainfo[k] = list_from_file(v)
|
||||
except (TypeError, FileNotFoundError):
|
||||
warnings.warn(f'{v} is not a meta file, simply parsed as '
|
||||
'meta information')
|
||||
print_log(
|
||||
f'{v} is not a meta file, simply parsed as meta '
|
||||
'information',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
cls_metainfo[k] = v
|
||||
else:
|
||||
cls_metainfo[k] = v
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import bisect
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import List, Sequence, Tuple, Union
|
||||
|
||||
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
|
||||
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import DATASETS
|
||||
from .base_dataset import BaseDataset, force_full_init
|
||||
|
||||
|
@ -148,8 +149,11 @@ class ConcatDataset(_ConcatDataset):
|
|||
|
||||
def __getitem__(self, idx):
|
||||
if not self._fully_initialized:
|
||||
warnings.warn('Please call `full_init` method manually to '
|
||||
'accelerate the speed.')
|
||||
print_log(
|
||||
'Please call `full_init` method manually to '
|
||||
'accelerate the speed.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.full_init()
|
||||
dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
|
||||
return self.datasets[dataset_idx][sample_idx]
|
||||
|
@ -263,8 +267,11 @@ class RepeatDataset:
|
|||
|
||||
def __getitem__(self, idx):
|
||||
if not self._fully_initialized:
|
||||
warnings.warn('Please call `full_init` method manually to '
|
||||
'accelerate the speed.')
|
||||
print_log(
|
||||
'Please call `full_init` method manually to accelerate the '
|
||||
'speed.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.full_init()
|
||||
|
||||
sample_idx = self._get_ori_dataset_idx(idx)
|
||||
|
@ -470,9 +477,12 @@ class ClassBalancedDataset:
|
|||
return self.dataset.get_data_info(sample_idx)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
warnings.warn('Please call `full_init` method manually to '
|
||||
'accelerate the speed.')
|
||||
if not self._fully_initialized:
|
||||
print_log(
|
||||
'Please call `full_init` method manually to accelerate '
|
||||
'the speed.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.full_init()
|
||||
|
||||
ori_index = self._get_ori_dataset_idx(idx)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
import logging
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
|
||||
|
@ -44,8 +44,11 @@ class BaseMetric(metaclass=ABCMeta):
|
|||
self.results: List[Any] = []
|
||||
self.prefix = prefix or self.default_prefix
|
||||
if self.prefix is None:
|
||||
warnings.warn('The prefix is not set in metric class '
|
||||
f'{self.__class__.__name__}.')
|
||||
print_log(
|
||||
'The prefix is not set in metric class '
|
||||
f'{self.__class__.__name__}.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
|
||||
@property
|
||||
def dataset_meta(self) -> Optional[dict]:
|
||||
|
@ -97,10 +100,12 @@ class BaseMetric(metaclass=ABCMeta):
|
|||
names of the metrics, and the values are corresponding results.
|
||||
"""
|
||||
if len(self.results) == 0:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
f'{self.__class__.__name__} got empty `self.results`. Please '
|
||||
'ensure that the processed results are properly added into '
|
||||
'`self.results` in `process` method.')
|
||||
'`self.results` in `process` method.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
|
||||
results = collect_results(self.results, size, self.collect_device)
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
import logging
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from mmengine.logging import print_log
|
||||
|
||||
|
||||
class BaseStorageBackend(metaclass=ABCMeta):
|
||||
"""Abstract class of storage backends.
|
||||
|
@ -19,8 +21,10 @@ class BaseStorageBackend(metaclass=ABCMeta):
|
|||
|
||||
@property
|
||||
def allow_symlink(self):
|
||||
warnings.warn('allow_symlink will be deprecated in future',
|
||||
DeprecationWarning)
|
||||
print_log(
|
||||
'allow_symlink will be deprecated in future',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
return self._allow_symlink
|
||||
|
||||
@property
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import inspect
|
||||
import warnings
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Iterator, Optional, Tuple, Union
|
||||
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.utils import is_filepath
|
||||
from .backends import (BaseStorageBackend, HTTPBackend, LmdbBackend,
|
||||
LocalBackend, MemcachedBackend, PetrelBackend)
|
||||
|
@ -14,9 +15,11 @@ class HardDiskBackend(LocalBackend):
|
|||
"""Raw hard disks storage backend."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
'"HardDiskBackend" is the alias of "LocalBackend" '
|
||||
'and the former will be deprecated in future.', DeprecationWarning)
|
||||
'and the former will be deprecated in future.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
|
@ -83,11 +86,12 @@ class FileClient:
|
|||
client: Any
|
||||
|
||||
def __new__(cls, backend=None, prefix=None, **kwargs):
|
||||
warnings.warn(
|
||||
print_log(
|
||||
'"FileClient" will be deprecated in future. Please use io '
|
||||
'functions in '
|
||||
'https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io', # noqa: E501
|
||||
DeprecationWarning)
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
if backend is None and prefix is None:
|
||||
backend = 'disk'
|
||||
if backend is not None and backend not in cls._backends:
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from math import inf
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from mmengine.dist import is_main_process
|
||||
from mmengine.fileio import FileClient, get_file_backend
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import HOOKS
|
||||
from mmengine.utils import is_list_of, is_seq_of
|
||||
from .hook import Hook
|
||||
|
@ -138,9 +139,11 @@ class CheckpointHook(Hook):
|
|||
self.args = kwargs
|
||||
|
||||
if file_client_args is not None:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
'"file_client_args" will be deprecated in future. '
|
||||
'Please use "backend_args" instead', DeprecationWarning)
|
||||
'Please use "backend_args" instead',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
if backend_args is not None:
|
||||
raise ValueError(
|
||||
'"file_client_args" and "backend_args" cannot be set '
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
@ -12,6 +12,7 @@ import torch
|
|||
from mmengine.fileio import FileClient, dump
|
||||
from mmengine.fileio.io import get_file_backend
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import HOOKS
|
||||
from mmengine.utils import is_tuple_of, scandir
|
||||
|
||||
|
@ -94,9 +95,11 @@ class LoggerHook(Hook):
|
|||
self.out_dir = out_dir
|
||||
|
||||
if file_client_args is not None:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
'"file_client_args" will be deprecated in future. '
|
||||
'Please use "backend_args" instead', DeprecationWarning)
|
||||
'Please use "backend_args" instead',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
if backend_args is not None:
|
||||
raise ValueError(
|
||||
'"file_client_args" and "backend_args" cannot be set '
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from mmengine.dist import master_only
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import HOOKS
|
||||
|
||||
|
||||
|
@ -18,7 +19,7 @@ def check_kineto() -> bool: # noqa
|
|||
if torch.autograd.kineto_available():
|
||||
kineto_exist = True
|
||||
except AttributeError:
|
||||
warnings.warn('NO KINETO')
|
||||
print_log('NO KINETO', logger='current', level=logging.WARNING)
|
||||
return kineto_exist
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
@ -8,6 +8,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import MODELS
|
||||
|
||||
|
||||
|
@ -184,11 +185,13 @@ class ExponentialMovingAverage(BaseAveragedModel):
|
|||
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
|
||||
f'but got {momentum}'
|
||||
if momentum > 0.5:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
'The value of momentum in EMA is usually a small number,'
|
||||
'which is different from the conventional notion of '
|
||||
f'momentum but got {momentum}. Please make sure the '
|
||||
f'value is correct.')
|
||||
f'value is correct.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.momentum = momentum
|
||||
|
||||
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import logging
|
||||
import warnings
|
||||
from abc import ABCMeta
|
||||
from collections import defaultdict
|
||||
from logging import FileHandler
|
||||
|
@ -139,8 +138,11 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
|
|||
initialize(self, pretrained_cfg)
|
||||
self._is_init = True
|
||||
else:
|
||||
warnings.warn(f'init_weights of {self.__class__.__name__} has '
|
||||
f'been called more than once.')
|
||||
print_log(
|
||||
f'init_weights of {self.__class__.__name__} has '
|
||||
f'been called more than once.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
|
||||
if is_top_level_module:
|
||||
# self._dump_init_info(logger_name)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
@ -205,8 +205,11 @@ class DefaultOptimWrapperConstructor:
|
|||
for name, param in module.named_parameters(recurse=False):
|
||||
param_group = {'params': [param]}
|
||||
if bypass_duplicate and self._is_in(param_group, params):
|
||||
warnings.warn(f'{prefix} is duplicate. It is skipped since '
|
||||
f'bypass_duplicate={bypass_duplicate}')
|
||||
print_log(
|
||||
f'{prefix} is duplicate. It is skipped since '
|
||||
f'bypass_duplicate={bypass_duplicate}',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
continue
|
||||
if not param.requires_grad:
|
||||
params.append(param_group)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import datetime
|
||||
import logging
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from mmengine.fileio import dump
|
||||
|
@ -107,10 +107,13 @@ def init_default_scope(scope: str) -> None:
|
|||
return
|
||||
current_scope = DefaultScope.get_current_instance() # type: ignore
|
||||
if current_scope.scope_name != scope: # type: ignore
|
||||
warnings.warn('The current default scope ' # type: ignore
|
||||
f'"{current_scope.scope_name}" is not "{scope}", '
|
||||
'`init_default_scope` will force set the current'
|
||||
f'default scope to "{scope}".')
|
||||
print_log(
|
||||
'The current default scope ' # type: ignore
|
||||
f'"{current_scope.scope_name}" is not "{scope}", '
|
||||
'`init_default_scope` will force set the current'
|
||||
f'default scope to "{scope}".',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
# avoid name conflict
|
||||
new_instance_name = f'{scope}-{datetime.datetime.now()}'
|
||||
DefaultScope.get_instance(new_instance_name, scope_name=scope)
|
||||
|
|
|
@ -5,7 +5,6 @@ import os
|
|||
import os.path as osp
|
||||
import pkgutil
|
||||
import re
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from importlib import import_module
|
||||
from tempfile import TemporaryDirectory
|
||||
|
@ -425,9 +424,11 @@ def load_from_torchvision(filename, map_location=None):
|
|||
"""
|
||||
model_urls = get_torchvision_models()
|
||||
if filename.startswith('modelzoo://'):
|
||||
warnings.warn(
|
||||
print_log(
|
||||
'The URL scheme of "modelzoo://" is deprecated, please '
|
||||
'use "torchvision://" instead', DeprecationWarning)
|
||||
'use "torchvision://" instead',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
model_name = filename[11:]
|
||||
else:
|
||||
model_name = filename[14:]
|
||||
|
@ -459,10 +460,11 @@ def load_from_openmmlab(filename, map_location=None):
|
|||
|
||||
deprecated_urls = get_deprecated_model_names()
|
||||
if model_name in deprecated_urls:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
f'{prefix_str}{model_name} is deprecated in favor '
|
||||
f'of {prefix_str}{deprecated_urls[model_name]}',
|
||||
DeprecationWarning)
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
model_name = deprecated_urls[model_name]
|
||||
model_url = model_urls[model_name]
|
||||
# check if is url
|
||||
|
@ -715,9 +717,11 @@ def save_checkpoint(checkpoint,
|
|||
New in v0.2.0.
|
||||
"""
|
||||
if file_client_args is not None:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
'"file_client_args" will be deprecated in future. '
|
||||
'Please use "backend_args" instead', DeprecationWarning)
|
||||
'Please use "backend_args" instead',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
if backend_args is not None:
|
||||
raise ValueError(
|
||||
'"file_client_args" and "backend_args" cannot be set '
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import bisect
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import LOOPS
|
||||
from .amp import autocast
|
||||
from .base_loop import BaseLoop
|
||||
|
@ -56,10 +57,12 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||
self.runner.visualizer.dataset_meta = \
|
||||
self.dataloader.dataset.metainfo
|
||||
else:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
||||
'metainfo. ``dataset_meta`` in visualizer will be '
|
||||
'None.')
|
||||
'None.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
|
||||
self.dynamic_milestones, self.dynamic_intervals = \
|
||||
calc_dynamic_intervals(
|
||||
|
@ -160,11 +163,14 @@ class _InfiniteDataloaderIterator:
|
|||
try:
|
||||
data = next(self._iterator)
|
||||
except StopIteration:
|
||||
warnings.warn('Reach the end of the dataloader, it will be '
|
||||
'restarted and continue to iterate. It is '
|
||||
'recommended to use '
|
||||
'`mmengine.dataset.InfiniteSampler` to enable the '
|
||||
'dataloader to iterate infinitely.')
|
||||
print_log(
|
||||
'Reach the end of the dataloader, it will be '
|
||||
'restarted and continue to iterate. It is '
|
||||
'recommended to use '
|
||||
'`mmengine.dataset.InfiniteSampler` to enable the '
|
||||
'dataloader to iterate infinitely.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self._epoch += 1
|
||||
if hasattr(self._dataloader, 'sampler') and hasattr(
|
||||
self._dataloader.sampler, 'set_epoch'):
|
||||
|
@ -226,10 +232,12 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
self.runner.visualizer.dataset_meta = \
|
||||
self.dataloader.dataset.metainfo
|
||||
else:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
||||
'metainfo. ``dataset_meta`` in visualizer will be '
|
||||
'None.')
|
||||
'None.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
# get the iterator of the dataloader
|
||||
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
|
||||
|
||||
|
@ -338,10 +346,12 @@ class ValLoop(BaseLoop):
|
|||
self.runner.visualizer.dataset_meta = \
|
||||
self.dataloader.dataset.metainfo
|
||||
else:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
||||
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
||||
'visualizer will be None.')
|
||||
'visualizer will be None.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.fp16 = fp16
|
||||
|
||||
def run(self) -> dict:
|
||||
|
@ -408,10 +418,12 @@ class TestLoop(BaseLoop):
|
|||
self.runner.visualizer.dataset_meta = \
|
||||
self.dataloader.dataset.metainfo
|
||||
else:
|
||||
warnings.warn(
|
||||
print_log(
|
||||
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
||||
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
||||
'visualizer will be None.')
|
||||
'visualizer will be None.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.fp16 = fp16
|
||||
|
||||
def run(self) -> dict:
|
||||
|
|
|
@ -1387,8 +1387,11 @@ class Runner:
|
|||
# `persistent_workers` requires pytorch version >= 1.7
|
||||
if ('persistent_workers' in dataloader_cfg
|
||||
and digit_version(TORCH_VERSION) < digit_version('1.7.0')):
|
||||
warnings.warn('`persistent_workers` is only available when '
|
||||
'pytorch version >= 1.7')
|
||||
print_log(
|
||||
'`persistent_workers` is only available when '
|
||||
'pytorch version >= 1.7',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
dataloader_cfg.pop('persistent_workers')
|
||||
|
||||
# The default behavior of `collat_fn` in dataloader is to
|
||||
|
@ -1956,10 +1959,13 @@ class Runner:
|
|||
current_seed = self._randomness_cfg.get('seed')
|
||||
if resumed_seed is not None and resumed_seed != current_seed:
|
||||
if current_seed is not None:
|
||||
warnings.warn(f'The value of random seed in the '
|
||||
f'checkpoint "{resumed_seed}" is '
|
||||
f'different from the value in '
|
||||
f'`randomness` config "{current_seed}"')
|
||||
print_log(
|
||||
f'The value of random seed in the '
|
||||
f'checkpoint "{resumed_seed}" is '
|
||||
f'different from the value in '
|
||||
f'`randomness` config "{current_seed}"',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self._randomness_cfg.update(seed=resumed_seed)
|
||||
self.set_randomness(**self._randomness_cfg)
|
||||
|
||||
|
@ -1970,11 +1976,13 @@ class Runner:
|
|||
# np.ndarray, which cannot be directly judged as equal or not,
|
||||
# therefore we just compared their dumped results.
|
||||
if pickle.dumps(resumed_dataset_meta) != pickle.dumps(dataset_meta):
|
||||
warnings.warn(
|
||||
print_log(
|
||||
'The dataset metainfo from the resumed checkpoint is '
|
||||
'different from the current training dataset, please '
|
||||
'check the correctness of the checkpoint or the training '
|
||||
'dataset.')
|
||||
'dataset.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
|
||||
self.message_hub.load_state_dict(checkpoint['message_hub'])
|
||||
|
||||
|
|
|
@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
|||
if TYPE_CHECKING:
|
||||
from matplotlib.font_manager import FontProperties
|
||||
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -13,6 +15,7 @@ import torch.nn.functional as F
|
|||
|
||||
from mmengine.config import Config
|
||||
from mmengine.dist import master_only
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import VISBACKENDS, VISUALIZERS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from mmengine.utils import ManagerMixin
|
||||
|
@ -163,8 +166,11 @@ class Visualizer(ManagerMixin):
|
|||
self._vis_backends: Union[Dict, Dict[str, 'BaseVisBackend']] = dict()
|
||||
|
||||
if save_dir is None:
|
||||
warnings.warn('`Visualizer` backend is not initialized '
|
||||
'because save_dir is None.')
|
||||
print_log(
|
||||
'`Visualizer` backend is not initialized '
|
||||
'because save_dir is None.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
elif vis_backends is not None:
|
||||
assert len(vis_backends) > 0, 'empty list'
|
||||
names = [
|
||||
|
|
|
@ -8,6 +8,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.registry import METRICS
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
|
@ -110,7 +111,8 @@ class TestEvaluator(TestCase):
|
|||
# Test empty results
|
||||
cfg = dict(type='ToyMetric', dummy_metrics=dict(accuracy=1.0))
|
||||
evaluator = Evaluator(cfg)
|
||||
with self.assertWarnsRegex(UserWarning, 'got empty `self.results`.'):
|
||||
# Warning should be raised if the results are empty
|
||||
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
|
||||
evaluator.evaluate(0)
|
||||
|
||||
def test_composed_metrics(self):
|
||||
|
@ -185,8 +187,10 @@ class TestEvaluator(TestCase):
|
|||
|
||||
def test_prefix(self):
|
||||
cfg = dict(type='NonPrefixedMetric')
|
||||
with self.assertWarnsRegex(UserWarning, 'The prefix is not set'):
|
||||
_ = Evaluator(cfg)
|
||||
logger = MMLogger.get_current_instance()
|
||||
# Warning should be raised if prefix is not set.
|
||||
with self.assertLogs(logger, 'WARNING'):
|
||||
Evaluator(cfg)
|
||||
|
||||
def test_get_metric_value(self):
|
||||
|
||||
|
|
|
@ -76,10 +76,11 @@ class TestCheckpointHook:
|
|||
|
||||
def test_init(self, tmp_path):
|
||||
# Test file_client_args and backend_args
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match='"file_client_args" will be deprecated in future'):
|
||||
CheckpointHook(file_client_args={'backend': 'disk'})
|
||||
# TODO: Refactor this test case
|
||||
# with pytest.warns(
|
||||
# DeprecationWarning,
|
||||
# match='"file_client_args" will be deprecated in future'):
|
||||
# CheckpointHook(file_client_args={'backend': 'disk'})
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
|
|
|
@ -29,11 +29,12 @@ class TestLoggerHook:
|
|||
LoggerHook(file_client_args=dict(enable_mc=True))
|
||||
|
||||
# test `file_client_args` and `backend_args`
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match='"file_client_args" will be deprecated in future'):
|
||||
logger_hook = LoggerHook(
|
||||
out_dir='tmp.txt', file_client_args={'backend': 'disk'})
|
||||
# TODO Refine this unit test
|
||||
# with pytest.warns(
|
||||
# DeprecationWarning,
|
||||
# match='"file_client_args" will be deprecated in future'):
|
||||
# logger_hook = LoggerHook(
|
||||
# out_dir='tmp.txt', file_client_args={'backend': 'disk'})
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
|
|
|
@ -38,8 +38,8 @@ def test_get_config():
|
|||
|
||||
# Test load mmpose
|
||||
get_config(
|
||||
'mmpose::face/2d_kpt_sview_rgb_img/deeppose/wflw/res50_wflw_256x256'
|
||||
'.py')
|
||||
'mmpose::face_2d_keypoint/topdown_heatmap/wflw/td-hm_hrnetv2-w18_8xb64-60e_wflw-256x256.py' # noqa E501
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
|
|
@ -4,6 +4,7 @@ from unittest import TestCase
|
|||
|
||||
import torch
|
||||
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.model import (ExponentialMovingAverage, MomentumAnnealingEMA,
|
||||
StochasticWeightAverage)
|
||||
from mmengine.testing import assert_allclose
|
||||
|
@ -94,9 +95,9 @@ class TestAveragedModel(TestCase):
|
|||
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
|
||||
ExponentialMovingAverage(model, momentum=3)
|
||||
|
||||
with self.assertWarnsRegex(
|
||||
Warning,
|
||||
'The value of momentum in EMA is usually a small number'):
|
||||
# Warning should be raised if the value of momentum in EMA is
|
||||
# a large number
|
||||
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
|
||||
ExponentialMovingAverage(model, momentum=0.9)
|
||||
|
|
|
@ -10,6 +10,7 @@ import torch.nn as nn
|
|||
from torch.distributed.rpc import is_available
|
||||
|
||||
from mmengine.dist import get_rank
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
|
||||
DefaultOptimWrapperConstructor, OptimWrapper,
|
||||
build_optim_wrapper)
|
||||
|
@ -592,10 +593,9 @@ class TestBuilder(TestCase):
|
|||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
|
||||
self.assertWarnsRegex(
|
||||
Warning,
|
||||
'conv3.0 is duplicate. It is skipped since bypass_duplicate=True',
|
||||
lambda: optim_constructor(model))
|
||||
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
|
||||
# Warning should be raised since conv3.0 is a duplicate param.
|
||||
optim_constructor(model)
|
||||
optim_wrapper = optim_constructor(model)
|
||||
model_parameters = list(model.parameters())
|
||||
num_params = 14 if MMCV_FULL_AVAILABLE else 11
|
||||
|
@ -607,10 +607,9 @@ class TestBuilder(TestCase):
|
|||
# test DefaultOptimWrapperConstructor when the params in shared
|
||||
# modules do not require grad
|
||||
model.conv1[0].requires_grad_(False)
|
||||
self.assertWarnsRegex(
|
||||
Warning,
|
||||
'conv3.0 is duplicate. It is skipped since bypass_duplicate=True',
|
||||
lambda: optim_constructor(model))
|
||||
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
|
||||
# Warning should be raised since conv3.0 is a duplicate param.
|
||||
optim_constructor(model)
|
||||
optim_wrapper = optim_constructor(model)
|
||||
model_parameters = list(model.parameters())
|
||||
num_params = 14 if MMCV_FULL_AVAILABLE else 11
|
||||
|
|
|
@ -4,6 +4,7 @@ import os.path as osp
|
|||
from tempfile import TemporaryDirectory
|
||||
from unittest import TestCase, skipIf
|
||||
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.registry import (DefaultScope, Registry,
|
||||
count_registered_modules, init_default_scope,
|
||||
root, traverse_registry_tree)
|
||||
|
@ -75,6 +76,7 @@ class TestUtils(TestCase):
|
|||
# init default scope when another scope is init
|
||||
name = f'test-{datetime.datetime.now()}'
|
||||
DefaultScope.get_instance(name, scope_name='test')
|
||||
with self.assertWarnsRegex(
|
||||
Warning, 'The current default scope "test" is not "mmdet"'):
|
||||
# Warning should be raised since the current
|
||||
# default scope is not 'mmdet'
|
||||
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
|
||||
init_default_scope('mmdet')
|
||||
|
|
|
@ -1478,10 +1478,8 @@ class TestRunner(TestCase):
|
|||
cfg.train_cfg = dict(
|
||||
by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
|
||||
runner = Runner.from_cfg(cfg)
|
||||
with self.assertWarnsRegex(
|
||||
Warning,
|
||||
'Reach the end of the dataloader, it will be restarted and '
|
||||
'continue to iterate.'):
|
||||
# Warning should be raised since the sampler is not InfiniteSampler.
|
||||
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
|
||||
runner.train()
|
||||
|
||||
assert isinstance(runner.train_loop, IterBasedTrainLoop)
|
||||
|
@ -2073,11 +2071,8 @@ class TestRunner(TestCase):
|
|||
# ckpt_modified['meta']['seed'] = 123
|
||||
path_modified = osp.join(self.temp_dir, 'modified.pth')
|
||||
torch.save(ckpt_modified, path_modified)
|
||||
with self.assertWarnsRegex(
|
||||
Warning, 'The dataset metainfo from the resumed checkpoint is '
|
||||
'different from the current training dataset, please '
|
||||
'check the correctness of the checkpoint or the training '
|
||||
'dataset.'):
|
||||
# Warning should be raised since dataset_meta is not matched
|
||||
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
|
||||
runner.resume(path_modified)
|
||||
|
||||
# 1.3.3 test resume with unmatched seed
|
||||
|
@ -2085,8 +2080,8 @@ class TestRunner(TestCase):
|
|||
ckpt_modified['meta']['seed'] = 123
|
||||
path_modified = osp.join(self.temp_dir, 'modified.pth')
|
||||
torch.save(ckpt_modified, path_modified)
|
||||
with self.assertWarnsRegex(
|
||||
Warning, 'The value of random seed in the checkpoint'):
|
||||
# Warning should be raised since seed is not matched
|
||||
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
|
||||
runner.resume(path_modified)
|
||||
|
||||
# 1.3.3 test resume with no seed and dataset meta
|
||||
|
|
|
@ -10,6 +10,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
from mmengine import VISBACKENDS, Config
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
|
||||
|
@ -68,10 +69,8 @@ class TestVisualizer(TestCase):
|
|||
visualizer.get_image()
|
||||
|
||||
# test save_dir
|
||||
with pytest.warns(
|
||||
Warning,
|
||||
match='`Visualizer` backend is not initialized '
|
||||
'because save_dir is None.'):
|
||||
# Warning should be raised since no backend is initialized.
|
||||
with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'):
|
||||
Visualizer()
|
||||
|
||||
visualizer = Visualizer(
|
||||
|
|
Loading…
Reference in New Issue