[Fix] Fix lint (#1598)

* [Fix] Fix lint

* [Fix] Fix lint

* Update mmengine/dist/utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
Mashiro 2024-11-02 22:23:51 +08:00 committed by GitHub
parent c9b59962d6
commit cc3b74b5e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 146 additions and 134 deletions

View File

@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python 3.7 - name: Set up Python 3.10.15
uses: actions/setup-python@v2 uses: actions/setup-python@v2
with: with:
python-version: 3.7 python-version: '3.10.15'
- name: Install pre-commit hook - name: Install pre-commit hook
run: | run: |
pip install pre-commit pip install pre-commit

View File

@ -1,5 +1,9 @@
name: pr_stage_test name: pr_stage_test
env:
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true
on: on:
pull_request: pull_request:
paths-ignore: paths-ignore:

View File

@ -1,7 +1,11 @@
exclude: ^tests/data/ exclude: ^tests/data/
repos: repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8 - repo: https://github.com/pre-commit/pre-commit
rev: 5.0.4 rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks: hooks:
- id: flake8 - id: flake8
- repo: https://gitee.com/openmmlab/mirrors-isort - repo: https://gitee.com/openmmlab/mirrors-isort
@ -13,7 +17,7 @@ repos:
hooks: hooks:
- id: yapf - id: yapf
- repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
rev: v4.3.0 rev: v5.0.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
- id: check-yaml - id: check-yaml
@ -55,7 +59,7 @@ repos:
args: ["mmengine", "tests"] args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs - id: remove-improper-eol-in-cn-docs
- repo: https://gitee.com/openmmlab/mirrors-mypy - repo: https://gitee.com/openmmlab/mirrors-mypy
rev: v0.812 rev: v1.2.0
hooks: hooks:
- id: mypy - id: mypy
exclude: |- exclude: |-
@ -63,3 +67,4 @@ repos:
^examples ^examples
| ^docs | ^docs
) )
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]

View File

@ -1,7 +1,11 @@
exclude: ^tests/data/ exclude: ^tests/data/
repos: repos:
- repo: https://github.com/pre-commit/pre-commit
rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 5.0.4 rev: 7.1.1
hooks: hooks:
- id: flake8 - id: flake8
- repo: https://github.com/PyCQA/isort - repo: https://github.com/PyCQA/isort
@ -13,7 +17,7 @@ repos:
hooks: hooks:
- id: yapf - id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0 rev: v5.0.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
- id: check-yaml - id: check-yaml
@ -34,12 +38,8 @@ repos:
- mdformat-openmmlab - mdformat-openmmlab
- mdformat_frontmatter - mdformat_frontmatter
- linkify-it-py - linkify-it-py
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter - repo: https://github.com/myint/docformatter
rev: v1.3.1 rev: 06907d0
hooks: hooks:
- id: docformatter - id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"] args: ["--in-place", "--wrap-descriptions", "79"]
@ -55,7 +55,7 @@ repos:
args: ["mmengine", "tests"] args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs - id: remove-improper-eol-in-cn-docs
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812 rev: v1.2.0
hooks: hooks:
- id: mypy - id: mypy
exclude: |- exclude: |-
@ -63,3 +63,4 @@ repos:
^examples ^examples
| ^docs | ^docs
) )
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]

View File

@ -499,7 +499,7 @@ class BaseStrategy(metaclass=ABCMeta):
'"type" and "constructor" are not in ' '"type" and "constructor" are not in '
f'optimizer, but got {name}={optim}') f'optimizer, but got {name}={optim}')
optim_wrappers[name] = optim optim_wrappers[name] = optim
return OptimWrapperDict(**optim_wrappers) return OptimWrapperDict(**optim_wrappers) # type: ignore
else: else:
raise TypeError('optimizer wrapper should be an OptimWrapper ' raise TypeError('optimizer wrapper should be an OptimWrapper '
f'object or dict, but got {optim_wrapper}') f'object or dict, but got {optim_wrapper}')

View File

@ -361,7 +361,7 @@ class ColossalAIStrategy(BaseStrategy):
map_location: Union[str, Callable] = 'default', map_location: Union[str, Callable] = 'default',
callback: Optional[Callable] = None, callback: Optional[Callable] = None,
) -> dict: ) -> dict:
"""override this method since colossalai resume optimizer from filename """Override this method since colossalai resume optimizer from filename
directly.""" directly."""
self.logger.info(f'Resume checkpoint from {filename}') self.logger.info(f'Resume checkpoint from {filename}')

View File

@ -53,7 +53,7 @@ class DDPStrategy(SingleDeviceStrategy):
init_dist(launcher, backend, **kwargs) init_dist(launcher, backend, **kwargs)
def convert_model(self, model: nn.Module) -> nn.Module: def convert_model(self, model: nn.Module) -> nn.Module:
"""convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm`` """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
(SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers. (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.
Args: Args:

View File

@ -393,7 +393,7 @@ class Config:
def __init__( def __init__(
self, self,
cfg_dict: dict = None, cfg_dict: Optional[dict] = None,
cfg_text: Optional[str] = None, cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None, filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None, env_variables: Optional[dict] = None,
@ -1227,7 +1227,8 @@ class Config:
if base_code is not None: if base_code is not None:
base_code = ast.Expression( # type: ignore base_code = ast.Expression( # type: ignore
body=base_code.value) # type: ignore body=base_code.value) # type: ignore
base_files = eval(compile(base_code, '', mode='eval')) base_files = eval(compile(base_code, '',
mode='eval')) # type: ignore
else: else:
base_files = [] base_files = []
elif file_format in ('.yml', '.yaml', '.json'): elif file_format in ('.yml', '.yaml', '.json'):
@ -1288,7 +1289,7 @@ class Config:
def _merge_a_into_b(a: dict, def _merge_a_into_b(a: dict,
b: dict, b: dict,
allow_list_keys: bool = False) -> dict: allow_list_keys: bool = False) -> dict:
"""merge dict ``a`` into dict ``b`` (non-inplace). """Merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications. in-place modifications.
@ -1358,22 +1359,22 @@ class Config:
@property @property
def filename(self) -> str: def filename(self) -> str:
"""get file name of config.""" """Get file name of config."""
return self._filename return self._filename
@property @property
def text(self) -> str: def text(self) -> str:
"""get config text.""" """Get config text."""
return self._text return self._text
@property @property
def env_variables(self) -> dict: def env_variables(self) -> dict:
"""get used environment variables.""" """Get used environment variables."""
return self._env_variables return self._env_variables
@property @property
def pretty_text(self) -> str: def pretty_text(self) -> str:
"""get formatted python config text.""" """Get formatted python config text."""
indent = 4 indent = 4
@ -1727,17 +1728,17 @@ class Config:
class DictAction(Action): class DictAction(Action):
""" """Argparse action to split an argument into KEY=VALUE form on the first =
argparse action to split an argument into KEY=VALUE form and append to a dictionary.
on the first = and append to a dictionary. List options can
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3',
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
""" """
@staticmethod @staticmethod
def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]: def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
"""parse int/float/bool value in the string.""" """Parse int/float/bool value in the string."""
try: try:
return int(val) return int(val)
except ValueError: except ValueError:
@ -1822,7 +1823,7 @@ class DictAction(Action):
parser: ArgumentParser, parser: ArgumentParser,
namespace: Namespace, namespace: Namespace,
values: Union[str, Sequence[Any], None], values: Union[str, Sequence[Any], None],
option_string: str = None): option_string: str = None): # type: ignore
"""Parse Variables in string and add them into argparser. """Parse Variables in string and add them into argparser.
Args: Args:

View File

@ -563,7 +563,7 @@ def cast_data_device(
Tensor or list or dict: ``data`` was casted to ``device``. Tensor or list or dict: ``data`` was casted to ``device``.
""" """
if out is not None: if out is not None:
if type(data) != type(out): if type(data) is not type(out):
raise TypeError( raise TypeError(
'out should be the same type with data, but got data is ' 'out should be the same type with data, but got data is '
f'{type(data)} and out is {type(data)}') f'{type(data)} and out is {type(data)}')

View File

@ -175,11 +175,11 @@ class DumpResults(BaseMetric):
self.out_file_path = out_file_path self.out_file_path = out_file_path
def process(self, data_batch: Any, predictions: Sequence[dict]) -> None: def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
"""transfer tensors in predictions to CPU.""" """Transfer tensors in predictions to CPU."""
self.results.extend(_to_cpu(predictions)) self.results.extend(_to_cpu(predictions))
def compute_metrics(self, results: list) -> dict: def compute_metrics(self, results: list) -> dict:
"""dump the prediction results to a pickle file.""" """Dump the prediction results to a pickle file."""
dump(results, self.out_file_path) dump(results, self.out_file_path)
print_log( print_log(
f'Results has been saved to {self.out_file_path}.', f'Results has been saved to {self.out_file_path}.',
@ -188,7 +188,7 @@ class DumpResults(BaseMetric):
def _to_cpu(data: Any) -> Any: def _to_cpu(data: Any) -> Any:
"""transfer all tensors and BaseDataElement to cpu.""" """Transfer all tensors and BaseDataElement to cpu."""
if isinstance(data, (Tensor, BaseDataElement)): if isinstance(data, (Tensor, BaseDataElement)):
return data.to('cpu') return data.to('cpu')
elif isinstance(data, list): elif isinstance(data, list):

View File

@ -233,7 +233,7 @@ class ProfilerHook(Hook):
self._export_chrome_trace(runner) self._export_chrome_trace(runner)
def after_train_iter(self, runner, batch_idx, data_batch, outputs): def after_train_iter(self, runner, batch_idx, data_batch, outputs):
"""profiler will call `step` method if it is not closed.""" """Profiler will call `step` method if it is not closed."""
if not self._closed: if not self._closed:
self.profiler.step() self.profiler.step()
if runner.iter == self.profile_times - 1 and not self.by_epoch: if runner.iter == self.profile_times - 1 and not self.by_epoch:

View File

@ -58,7 +58,7 @@ class HistoryBuffer:
self._statistics_methods.setdefault('mean', HistoryBuffer.mean) self._statistics_methods.setdefault('mean', HistoryBuffer.mean)
def update(self, log_val: Union[int, float], count: int = 1) -> None: def update(self, log_val: Union[int, float], count: int = 1) -> None:
"""update the log history. """Update the log history.
If the length of the buffer exceeds ``self._max_length``, the oldest If the length of the buffer exceeds ``self._max_length``, the oldest
element will be removed from the buffer. element will be removed from the buffer.

View File

@ -253,17 +253,17 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
dict or list: Data in the same format as the model input. dict or list: Data in the same format as the model input.
""" """
data = self.cast_data(data) # type: ignore data = self.cast_data(data) # type: ignore
_batch_inputs = data['inputs'] _batch_inputs = data['inputs'] # type: ignore
# Process data with `pseudo_collate`. # Process data with `pseudo_collate`.
if is_seq_of(_batch_inputs, torch.Tensor): if is_seq_of(_batch_inputs, torch.Tensor):
batch_inputs = [] batch_inputs = []
for _batch_input in _batch_inputs: for _batch_input in _batch_inputs:
# channel transform # channel transform
if self._channel_conversion: if self._channel_conversion:
_batch_input = _batch_input[[2, 1, 0], ...] _batch_input = _batch_input[[2, 1, 0], ...] # type: ignore
# Convert to float after channel conversion to ensure # Convert to float after channel conversion to ensure
# efficiency # efficiency
_batch_input = _batch_input.float() _batch_input = _batch_input.float() # type: ignore
# Normalization. # Normalization.
if self._enable_normalize: if self._enable_normalize:
if self.mean.shape[0] == 3: if self.mean.shape[0] == 3:
@ -302,7 +302,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
else: else:
raise TypeError('Output of `cast_data` should be a dict of ' raise TypeError('Output of `cast_data` should be a dict of '
'list/tuple with inputs and data_samples, ' 'list/tuple with inputs and data_samples, '
f'but got {type(data)}: {data}') f'but got {type(data)}: {data}') # type: ignore
data['inputs'] = batch_inputs data['inputs'] = batch_inputs # type: ignore
data.setdefault('data_samples', None) data.setdefault('data_samples', None) # type: ignore
return data return data # type: ignore

View File

@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0):
def bias_init_with_prob(prior_prob): def bias_init_with_prob(prior_prob):
"""initialize conv/fc bias value according to a given probability value.""" """Initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob)) bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init return bias_init
@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor,
std: float = 1., std: float = 1.,
a: float = -2., a: float = -2.,
b: float = 2.) -> Tensor: b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated r"""Fills the input Tensor with values drawn from a truncated normal
normal distribution. The values are effectively drawn from the distribution. The values are effectively drawn from the normal distribution
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
with values outside :math:`[a, b]` redrawn until they are within :math:`[a, b]` redrawn until they are within the bounds. The method used
the bounds. The method used for generating the random values works for generating the random values works best when :math:`a \leq \text{mean}
best when :math:`a \leq \text{mean} \leq b`. \leq b`.
Modified from Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

View File

@ -127,7 +127,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
auto_wrap_policy: Union[str, Callable, None] = None, auto_wrap_policy: Union[str, Callable, None] = None,
backward_prefetch: Union[str, BackwardPrefetch, None] = None, backward_prefetch: Union[str, BackwardPrefetch, None] = None,
mixed_precision: Union[dict, MixedPrecision, None] = None, mixed_precision: Union[dict, MixedPrecision, None] = None,
param_init_fn: Union[str, Callable[[nn.Module], None]] = None, param_init_fn: Union[str, Callable[
[nn.Module], None]] = None, # type: ignore # noqa: E501
use_orig_params: bool = True, use_orig_params: bool = True,
**kwargs, **kwargs,
): ):
@ -362,7 +363,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
group: Optional[dist.ProcessGroup] = None, group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""copied from pytorch 2.0.1 which has fixed some bugs.""" """Copied from pytorch 2.0.1 which has fixed some bugs."""
state_dict_settings = FullyShardedDataParallel.get_state_dict_type( state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
model) model)
return FullyShardedDataParallel._optim_state_dict_impl( return FullyShardedDataParallel._optim_state_dict_impl(
@ -384,7 +385,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
state_dict_config: Optional[StateDictConfig] = None, state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None, optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> StateDictSettings: ) -> StateDictSettings:
"""copied from pytorch 2.0.1 which has fixed some bugs.""" """Copied from pytorch 2.0.1 which has fixed some bugs."""
import torch.distributed.fsdp._traversal_utils as traversal_utils import torch.distributed.fsdp._traversal_utils as traversal_utils
_state_dict_type_to_config = { _state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullStateDictConfig, StateDictType.FULL_STATE_DICT: FullStateDictConfig,

View File

@ -123,8 +123,7 @@ class ApexOptimWrapper(OptimWrapper):
self._inner_count += 1 self._inner_count += 1
def state_dict(self) -> dict: def state_dict(self) -> dict:
"""Get the state dictionary of :attr:`optimizer` and """Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`.
:attr:`apex_amp`.
Based on the state dictionary of the optimizer, the returned state Based on the state dictionary of the optimizer, the returned state
dictionary will add a key named "apex_amp". dictionary will add a key named "apex_amp".

View File

@ -131,7 +131,7 @@ class DefaultOptimWrapperConstructor:
self._validate_cfg() self._validate_cfg()
def _validate_cfg(self) -> None: def _validate_cfg(self) -> None:
"""verify the correctness of the config.""" """Verify the correctness of the config."""
if not isinstance(self.paramwise_cfg, dict): if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, ' raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}') f'but got {type(self.paramwise_cfg)}')
@ -155,7 +155,7 @@ class DefaultOptimWrapperConstructor:
raise ValueError('base_wd should not be None') raise ValueError('base_wd should not be None')
def _is_in(self, param_group: dict, param_group_list: list) -> bool: def _is_in(self, param_group: dict, param_group_list: list) -> bool:
"""check whether the `param_group` is in the`param_group_list`""" """Check whether the `param_group` is in the`param_group_list`"""
assert is_list_of(param_group_list, dict) assert is_list_of(param_group_list, dict)
param = set(param_group['params']) param = set(param_group['params'])
param_set = set() param_set = set()

View File

@ -161,8 +161,7 @@ class OptimWrapperDict(OptimWrapper):
self.optim_wrappers[name].load_state_dict(_state_dict) self.optim_wrappers[name].load_state_dict(_state_dict)
def items(self) -> Iterator[Tuple[str, OptimWrapper]]: def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
"""A generator to get the name and corresponding """A generator to get the name and corresponding :obj:`OptimWrapper`"""
:obj:`OptimWrapper`"""
yield from self.optim_wrappers.items() yield from self.optim_wrappers.items()
def values(self) -> Iterator[OptimWrapper]: def values(self) -> Iterator[OptimWrapper]:

View File

@ -223,13 +223,13 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):
@PARAM_SCHEDULERS.register_module() @PARAM_SCHEDULERS.register_module()
class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler): class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
r"""Sets the learning rate of each parameter group according to the r"""Sets the learning rate of each parameter group according to the 1cycle
1cycle learning rate policy. The 1cycle policy anneals the learning learning rate policy. The 1cycle policy anneals the learning rate from an
rate from an initial learning rate to some maximum learning rate and then initial learning rate to some maximum learning rate and then from that
from that maximum learning rate to some minimum learning rate much lower maximum learning rate to some minimum learning rate much lower than the
than the initial learning rate. initial learning rate. This policy was initially described in the paper
This policy was initially described in the paper `Super-Convergence: `Super-Convergence: Very Fast Training of Neural Networks Using Large
Very Fast Training of Neural Networks Using Large Learning Rates`_. Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every The 1cycle learning rate policy changes the learning rate after every
batch. `step` should be called after a batch has been used for training. batch. `step` should be called after a batch has been used for training.

View File

@ -565,9 +565,9 @@ class ExponentialParamScheduler(_ParamScheduler):
@PARAM_SCHEDULERS.register_module() @PARAM_SCHEDULERS.register_module()
class CosineAnnealingParamScheduler(_ParamScheduler): class CosineAnnealingParamScheduler(_ParamScheduler):
r"""Set the parameter value of each parameter group using a cosine r"""Set the parameter value of each parameter group using a cosine annealing
annealing schedule, where :math:`\eta_{max}` is set to the initial value schedule, where :math:`\eta_{max}` is set to the initial value and
and :math:`T_{cur}` is the number of epochs since the last restart in SGDR: :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
.. math:: .. math::
\begin{aligned} \begin{aligned}
@ -617,7 +617,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts: .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983 https://arxiv.org/abs/1608.03983
""" """ # noqa: E501
def __init__(self, def __init__(self,
optimizer: Union[Optimizer, BaseOptimWrapper], optimizer: Union[Optimizer, BaseOptimWrapper],
@ -890,13 +890,13 @@ class PolyParamScheduler(_ParamScheduler):
@PARAM_SCHEDULERS.register_module() @PARAM_SCHEDULERS.register_module()
class OneCycleParamScheduler(_ParamScheduler): class OneCycleParamScheduler(_ParamScheduler):
r"""Sets the parameters of each parameter group according to the r"""Sets the parameters of each parameter group according to the 1cycle
1cycle learning rate policy. The 1cycle policy anneals the learning learning rate policy. The 1cycle policy anneals the learning rate from an
rate from an initial learning rate to some maximum learning rate and then initial learning rate to some maximum learning rate and then from that
from that maximum learning rate to some minimum learning rate much lower maximum learning rate to some minimum learning rate much lower than the
than the initial learning rate. initial learning rate. This policy was initially described in the paper
This policy was initially described in the paper `Super-Convergence: `Super-Convergence: Very Fast Training of Neural Networks Using Large
Very Fast Training of Neural Networks Using Large Learning Rates`_. Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every The 1cycle learning rate policy changes the learning rate after every
batch. `step` should be called after a batch has been used for training. batch. `step` should be called after a batch has been used for training.

View File

@ -81,7 +81,7 @@ class DefaultScope(ManagerMixin):
@classmethod @classmethod
@contextmanager @contextmanager
def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator: def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
"""overwrite the current default scope with `scope_name`""" """Overwrite the current default scope with `scope_name`"""
if scope_name is None: if scope_name is None:
yield yield
else: else:

View File

@ -332,7 +332,7 @@ class Registry:
return root return root
def import_from_location(self) -> None: def import_from_location(self) -> None:
"""import modules from the pre-defined locations in self._location.""" """Import modules from the pre-defined locations in self._location."""
if not self._imported: if not self._imported:
# Avoid circular import # Avoid circular import
from ..logging import print_log from ..logging import print_log

View File

@ -109,7 +109,7 @@ def init_default_scope(scope: str) -> None:
if current_scope.scope_name != scope: # type: ignore if current_scope.scope_name != scope: # type: ignore
print_log( print_log(
'The current default scope ' # type: ignore 'The current default scope ' # type: ignore
f'"{current_scope.scope_name}" is not "{scope}", ' f'"{current_scope.scope_name}" is not "{scope}", ' # type: ignore
'`init_default_scope` will force set the current' '`init_default_scope` will force set the current'
f'default scope to "{scope}".', f'default scope to "{scope}".',
logger='current', logger='current',

View File

@ -540,7 +540,7 @@ class FlexibleRunner:
@property @property
def hooks(self): def hooks(self):
"""list[:obj:`Hook`]: A list of registered hooks.""" """List[:obj:`Hook`]: A list of registered hooks."""
return self._hooks return self._hooks
@property @property
@ -1117,7 +1117,7 @@ class FlexibleRunner:
return '\n'.join(stage_hook_infos) return '\n'.join(stage_hook_infos)
def load_or_resume(self): def load_or_resume(self):
"""load or resume checkpoint.""" """Load or resume checkpoint."""
if self._has_loaded: if self._has_loaded:
return None return None
@ -1539,7 +1539,7 @@ class FlexibleRunner:
file_client_args: Optional[dict] = None, file_client_args: Optional[dict] = None,
save_optimizer: bool = True, save_optimizer: bool = True,
save_param_scheduler: bool = True, save_param_scheduler: bool = True,
meta: dict = None, meta: Optional[dict] = None,
by_epoch: bool = True, by_epoch: bool = True,
backend_args: Optional[dict] = None, backend_args: Optional[dict] = None,
): ):

View File

@ -309,7 +309,7 @@ class CheckpointLoader:
@classmethod @classmethod
def load_checkpoint(cls, filename, map_location=None, logger='current'): def load_checkpoint(cls, filename, map_location=None, logger='current'):
"""load checkpoint through URL scheme path. """Load checkpoint through URL scheme path.
Args: Args:
filename (str): checkpoint file name with given prefix filename (str): checkpoint file name with given prefix
@ -332,7 +332,7 @@ class CheckpointLoader:
@CheckpointLoader.register_scheme(prefixes='') @CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location): def load_from_local(filename, map_location):
"""load checkpoint by local file path. """Load checkpoint by local file path.
Args: Args:
filename (str): local checkpoint file path filename (str): local checkpoint file path
@ -353,7 +353,7 @@ def load_from_http(filename,
map_location=None, map_location=None,
model_dir=None, model_dir=None,
progress=os.isatty(0)): progress=os.isatty(0)):
"""load checkpoint through HTTP or HTTPS scheme path. In distributed """Load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0. setting, this function only download checkpoint at local rank 0.
Args: Args:
@ -386,7 +386,7 @@ def load_from_http(filename,
@CheckpointLoader.register_scheme(prefixes='pavi://') @CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None): def load_from_pavi(filename, map_location=None):
"""load checkpoint through the file path prefixed with pavi. In distributed """Load checkpoint through the file path prefixed with pavi. In distributed
setting, this function download ckpt at all ranks to different temporary setting, this function download ckpt at all ranks to different temporary
directories. directories.
@ -419,7 +419,7 @@ def load_from_pavi(filename, map_location=None):
@CheckpointLoader.register_scheme( @CheckpointLoader.register_scheme(
prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://']) prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://'])
def load_from_ceph(filename, map_location=None, backend='petrel'): def load_from_ceph(filename, map_location=None, backend='petrel'):
"""load checkpoint through the file path prefixed with s3. In distributed """Load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary setting, this function download ckpt at all ranks to different temporary
directories. directories.
@ -441,7 +441,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://')) @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None): def load_from_torchvision(filename, map_location=None):
"""load checkpoint through the file path prefixed with modelzoo or """Load checkpoint through the file path prefixed with modelzoo or
torchvision. torchvision.
Args: Args:
@ -467,7 +467,7 @@ def load_from_torchvision(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None): def load_from_openmmlab(filename, map_location=None):
"""load checkpoint through the file path prefixed with open-mmlab or """Load checkpoint through the file path prefixed with open-mmlab or
openmmlab. openmmlab.
Args: Args:
@ -510,7 +510,7 @@ def load_from_openmmlab(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes='mmcls://') @CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None): def load_from_mmcls(filename, map_location=None):
"""load checkpoint through the file path prefixed with mmcls. """Load checkpoint through the file path prefixed with mmcls.
Args: Args:
filename (str): checkpoint file path with mmcls prefix filename (str): checkpoint file path with mmcls prefix

View File

@ -579,7 +579,7 @@ class Runner:
@property @property
def hooks(self): def hooks(self):
"""list[:obj:`Hook`]: A list of registered hooks.""" """List[:obj:`Hook`]: A list of registered hooks."""
return self._hooks return self._hooks
@property @property
@ -720,7 +720,7 @@ class Runner:
def build_logger(self, def build_logger(self,
log_level: Union[int, str] = 'INFO', log_level: Union[int, str] = 'INFO',
log_file: str = None, log_file: Optional[str] = None,
**kwargs) -> MMLogger: **kwargs) -> MMLogger:
"""Build a global asscessable MMLogger. """Build a global asscessable MMLogger.
@ -1677,7 +1677,7 @@ class Runner:
return '\n'.join(stage_hook_infos) return '\n'.join(stage_hook_infos)
def load_or_resume(self) -> None: def load_or_resume(self) -> None:
"""load or resume checkpoint.""" """Load or resume checkpoint."""
if self._has_loaded: if self._has_loaded:
return None return None

View File

@ -387,7 +387,7 @@ class BaseDataElement:
return dict(self.metainfo_items()) return dict(self.metainfo_items())
def __setattr__(self, name: str, value: Any): def __setattr__(self, name: str, value: Any):
"""setattr is only used to set data.""" """Setattr is only used to set data."""
if name in ('_metainfo_fields', '_data_fields'): if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name): if not hasattr(self, name):
super().__setattr__(name, value) super().__setattr__(name, value)

View File

@ -135,7 +135,7 @@ class InstanceData(BaseDataElement):
""" """
def __setattr__(self, name: str, value: Sized): def __setattr__(self, name: str, value: Sized):
"""setattr is only used to set data. """Setattr is only used to set data.
The value must have the attribute of `__len__` and have the same length The value must have the attribute of `__len__` and have the same length
of `InstanceData`. of `InstanceData`.

View File

@ -57,6 +57,7 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
check_hash=False, check_hash=False,
file_name=None): file_name=None):
r"""Loads the Torch serialized object at the given URL. r"""Loads the Torch serialized object at the given URL.
If downloaded file is a zip file, it will be automatically decompressed If downloaded file is a zip file, it will be automatically decompressed
If the object is already present in `model_dir`, it's deserialized and If the object is already present in `model_dir`, it's deserialized and
returned. returned.

View File

@ -67,7 +67,7 @@ class TimeCounter:
instance.log_interval = log_interval instance.log_interval = log_interval
instance.warmup_interval = warmup_interval instance.warmup_interval = warmup_interval
instance.with_sync = with_sync instance.with_sync = with_sync # type: ignore
instance.tag = tag instance.tag = tag
instance.logger = logger instance.logger = logger
@ -127,7 +127,7 @@ class TimeCounter:
self.print_time(elapsed) self.print_time(elapsed)
def print_time(self, elapsed: Union[int, float]) -> None: def print_time(self, elapsed: Union[int, float]) -> None:
"""print times per count.""" """Print times per count."""
if self.__count >= self.warmup_interval: if self.__count >= self.warmup_interval:
self.__pure_inf_time += elapsed self.__pure_inf_time += elapsed

View File

@ -131,7 +131,7 @@ def tuple_cast(inputs, dst_type):
def is_seq_of(seq: Any, def is_seq_of(seq: Any,
expected_type: Union[Type, tuple], expected_type: Union[Type, tuple],
seq_type: Type = None) -> bool: seq_type: Optional[Type] = None) -> bool:
"""Check whether it is a sequence of some type. """Check whether it is a sequence of some type.
Args: Args:

View File

@ -69,11 +69,11 @@ def get_installed_path(package: str) -> str:
else: else:
raise e raise e
possible_path = osp.join(pkg.location, package) possible_path = osp.join(pkg.location, package) # type: ignore
if osp.exists(possible_path): if osp.exists(possible_path):
return possible_path return possible_path
else: else:
return osp.join(pkg.location, package2module(package)) return osp.join(pkg.location, package2module(package)) # type: ignore
def package2module(package: str): def package2module(package: str):

View File

@ -3,7 +3,7 @@ import sys
from collections.abc import Iterable from collections.abc import Iterable
from multiprocessing import Pool from multiprocessing import Pool
from shutil import get_terminal_size from shutil import get_terminal_size
from typing import Callable, Sequence from typing import Callable, Optional, Sequence
from .timer import Timer from .timer import Timer
@ -54,7 +54,7 @@ class ProgressBar:
self.timer = Timer() self.timer = Timer()
def update(self, num_tasks: int = 1): def update(self, num_tasks: int = 1):
"""update progressbar. """Update progressbar.
Args: Args:
num_tasks (int): Update step size. num_tasks (int): Update step size.
@ -142,8 +142,8 @@ def init_pool(process_num, initializer=None, initargs=None):
def track_parallel_progress(func: Callable, def track_parallel_progress(func: Callable,
tasks: Sequence, tasks: Sequence,
nproc: int, nproc: int,
initializer: Callable = None, initializer: Optional[Callable] = None,
initargs: tuple = None, initargs: Optional[tuple] = None,
bar_width: int = 50, bar_width: int = 50,
chunksize: int = 1, chunksize: int = 1,
skip_first: bool = False, skip_first: bool = False,

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing import Pool from multiprocessing import Pool
from typing import Callable, Iterable, Sized from typing import Callable, Iterable, Optional, Sized
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
TaskProgressColumn, TextColumn, TimeRemainingColumn) TaskProgressColumn, TextColumn, TimeRemainingColumn)
@ -47,7 +47,7 @@ def _tasks_with_index(tasks):
def track_progress_rich(func: Callable, def track_progress_rich(func: Callable,
tasks: Iterable = tuple(), tasks: Iterable = tuple(),
task_num: int = None, task_num: Optional[int] = None,
nproc: int = 1, nproc: int = 1,
chunksize: int = 1, chunksize: int = 1,
description: str = 'Processing', description: str = 'Processing',

View File

@ -161,7 +161,7 @@ class BaseVisBackend(metaclass=ABCMeta):
pass pass
def close(self) -> None: def close(self) -> None:
"""close an opened object.""" """Close an opened object."""
pass pass
@ -314,7 +314,7 @@ class LocalVisBackend(BaseVisBackend):
def _dump(self, value_dict: dict, file_path: str, def _dump(self, value_dict: dict, file_path: str,
file_format: str) -> None: file_format: str) -> None:
"""dump dict to file. """Dump dict to file.
Args: Args:
value_dict (dict) : The dict data to saved. value_dict (dict) : The dict data to saved.
@ -505,7 +505,7 @@ class WandbVisBackend(BaseVisBackend):
self._wandb.log(scalar_dict, commit=self._commit) self._wandb.log(scalar_dict, commit=self._commit)
def close(self) -> None: def close(self) -> None:
"""close an opened wandb object.""" """Close an opened wandb object."""
if hasattr(self, '_wandb'): if hasattr(self, '_wandb'):
self._wandb.join() self._wandb.join()
@ -629,7 +629,7 @@ class TensorboardVisBackend(BaseVisBackend):
self.add_scalar(key, value, step) self.add_scalar(key, value, step)
def close(self): def close(self):
"""close an opened tensorboard object.""" """Close an opened tensorboard object."""
if hasattr(self, '_tensorboard'): if hasattr(self, '_tensorboard'):
self._tensorboard.close() self._tensorboard.close()
@ -1135,7 +1135,7 @@ class NeptuneVisBackend(BaseVisBackend):
self._neptune[k].append(v, step=step) self._neptune[k].append(v, step=step)
def close(self) -> None: def close(self) -> None:
"""close an opened object.""" """Close an opened object."""
if hasattr(self, '_neptune'): if hasattr(self, '_neptune'):
self._neptune.stop() self._neptune.stop()
@ -1282,7 +1282,7 @@ class DVCLiveVisBackend(BaseVisBackend):
self.add_scalar(key, value, step, **kwargs) self.add_scalar(key, value, step, **kwargs)
def close(self) -> None: def close(self) -> None:
"""close an opened dvclive object.""" """Close an opened dvclive object."""
if not hasattr(self, '_dvclive'): if not hasattr(self, '_dvclive'):
return return

View File

@ -356,7 +356,7 @@ class Visualizer(ManagerMixin):
@master_only @master_only
def get_backend(self, name) -> 'BaseVisBackend': def get_backend(self, name) -> 'BaseVisBackend':
"""get vis backend by name. """Get vis backend by name.
Args: Args:
name (str): The name of vis backend name (str): The name of vis backend
@ -1145,7 +1145,7 @@ class Visualizer(ManagerMixin):
pass pass
def close(self) -> None: def close(self) -> None:
"""close an opened object.""" """Close an opened object."""
for vis_backend in self._vis_backends.values(): for vis_backend in self._vis_backends.values():
vis_backend.close() vis_backend.close()

View File

@ -843,8 +843,8 @@ class TestConfig:
assert cfg_dict['item4'] == 'test' assert cfg_dict['item4'] == 'test'
assert '_delete_' not in cfg_dict['item1'] assert '_delete_' not in cfg_dict['item1']
assert type(cfg_dict['item1']) == ConfigDict assert type(cfg_dict['item1']) is ConfigDict
assert type(cfg_dict['item2']) == ConfigDict assert type(cfg_dict['item2']) is ConfigDict
def _merge_intermediate_variable(self): def _merge_intermediate_variable(self):

View File

@ -300,8 +300,8 @@ except ImportError:
get_inputs.append(filepath) get_inputs.append(filepath)
with build_temporary_directory() as tmp_dir, \ with build_temporary_directory() as tmp_dir, \
patch.object(backend, 'put', side_effect=put),\ patch.object(backend, 'put', side_effect=put), \
patch.object(backend, 'get', side_effect=get),\ patch.object(backend, 'get', side_effect=get), \
patch.object(backend, 'exists', return_value=False): patch.object(backend, 'exists', return_value=False):
tmp_dir = tmp_dir.replace('\\', '/') tmp_dir = tmp_dir.replace('\\', '/')
dst = f'{tmp_dir}/dir' dst = f'{tmp_dir}/dir'
@ -351,7 +351,7 @@ except ImportError:
with build_temporary_directory() as tmp_dir, \ with build_temporary_directory() as tmp_dir, \
patch.object(backend, 'copyfile_from_local', patch.object(backend, 'copyfile_from_local',
side_effect=copyfile_from_local),\ side_effect=copyfile_from_local), \
patch.object(backend, 'exists', return_value=False): patch.object(backend, 'exists', return_value=False):
backend.copytree_from_local(tmp_dir, self.petrel_dir) backend.copytree_from_local(tmp_dir, self.petrel_dir)
@ -427,7 +427,7 @@ except ImportError:
def remove(filepath): def remove(filepath):
inputs.append(filepath) inputs.append(filepath)
with build_temporary_directory() as tmp_dir,\ with build_temporary_directory() as tmp_dir, \
patch.object(backend, 'remove', side_effect=remove): patch.object(backend, 'remove', side_effect=remove):
backend.rmtree(tmp_dir) backend.rmtree(tmp_dir)

View File

@ -13,7 +13,8 @@ from mmengine.testing import assert_allclose
class TestAveragedModel(TestCase): class TestAveragedModel(TestCase):
"""Test the AveragedModel class. """Test the AveragedModel class.
Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py Some test cases are referenced from
https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
""" # noqa: E501 """ # noqa: E501
def _test_swa_model(self, net_device, avg_device): def _test_swa_model(self, net_device, avg_device):

View File

@ -102,7 +102,7 @@ class TestLRScheduler(TestCase):
rtol=0) rtol=0)
def test_scheduler_before_optim_warning(self): def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer.""" """Warns if scheduler is used before optimizer."""
def call_sch_before_optim(): def call_sch_before_optim():
scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3) scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3)

View File

@ -120,7 +120,7 @@ class TestMomentumScheduler(TestCase):
rtol=0) rtol=0)
def test_scheduler_before_optim_warning(self): def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer.""" """Warns if scheduler is used before optimizer."""
def call_sch_before_optim(): def call_sch_before_optim():
scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3) scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3)

View File

@ -127,7 +127,7 @@ class TestParameterScheduler(TestCase):
rtol=0) rtol=0)
def test_scheduler_before_optim_warning(self): def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer.""" """Warns if scheduler is used before optimizer."""
def call_sch_before_optim(): def call_sch_before_optim():
scheduler = StepParamScheduler( scheduler = StepParamScheduler(

View File

@ -251,7 +251,7 @@ def test_load_checkpoint_metadata():
def _load_from_state_dict(self, state_dict, prefix, local_metadata, def _load_from_state_dict(self, state_dict, prefix, local_metadata,
*args, **kwargs): *args, **kwargs):
"""load checkpoints.""" """Load checkpoints."""
# Names of some parameters in has been changed. # Names of some parameters in has been changed.
version = local_metadata.get('version', None) version = local_metadata.get('version', None)

View File

@ -2226,7 +2226,7 @@ class TestRunner(TestCase):
@HOOKS.register_module(force=True) @HOOKS.register_module(force=True)
class TestWarmupHook(Hook): class TestWarmupHook(Hook):
"""test custom train loop.""" """Test custom train loop."""
def before_warmup_iter(self, runner, data_batch=None): def before_warmup_iter(self, runner, data_batch=None):
before_warmup_iter_results.append('before') before_warmup_iter_results.append('before')

View File

@ -64,7 +64,7 @@ class TestBaseDataElement(TestCase):
return metainfo, data return metainfo, data
def is_equal(self, x, y): def is_equal(self, x, y):
assert type(x) == type(y) assert type(x) is type(y)
if isinstance( if isinstance(
x, (int, float, str, list, tuple, dict, set, BaseDataElement)): x, (int, float, str, list, tuple, dict, set, BaseDataElement)):
return x == y return x == y
@ -141,7 +141,7 @@ class TestBaseDataElement(TestCase):
# test new() with no arguments # test new() with no arguments
new_instances = instances.new() new_instances = instances.new()
assert type(new_instances) == type(instances) assert type(new_instances) is type(instances)
# After deepcopy, the address of new data'element will be same as # After deepcopy, the address of new data'element will be same as
# origin, but when change new data' element will not effect the origin # origin, but when change new data' element will not effect the origin
# element and will have new address # element and will have new address
@ -154,7 +154,7 @@ class TestBaseDataElement(TestCase):
# test new() with arguments # test new() with arguments
metainfo, data = self.setup_data() metainfo, data = self.setup_data()
new_instances = instances.new(metainfo=metainfo, **data) new_instances = instances.new(metainfo=metainfo, **data)
assert type(new_instances) == type(instances) assert type(new_instances) is type(instances)
assert id(new_instances.gt_instances) != id(instances.gt_instances) assert id(new_instances.gt_instances) != id(instances.gt_instances)
_, new_data = self.setup_data() _, new_data = self.setup_data()
new_instances.set_data(new_data) new_instances.set_data(new_data)
@ -168,7 +168,7 @@ class TestBaseDataElement(TestCase):
metainfo, data = self.setup_data() metainfo, data = self.setup_data()
instances = BaseDataElement(metainfo=metainfo, **data) instances = BaseDataElement(metainfo=metainfo, **data)
new_instances = instances.clone() new_instances = instances.clone()
assert type(new_instances) == type(instances) assert type(new_instances) is type(instances)
def test_set_metainfo(self): def test_set_metainfo(self):
metainfo, _ = self.setup_data() metainfo, _ = self.setup_data()

View File

@ -45,7 +45,7 @@ class MockVisBackend:
self._add_scalars = True self._add_scalars = True
def close(self) -> None: def close(self) -> None:
"""close an opened object.""" """Close an opened object."""
self._close = True self._close = True