[Fix] Fix lint (#1598)

* [Fix] Fix lint

* [Fix] Fix lint

* Update mmengine/dist/utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Mashiro 2024-11-02 22:23:51 +08:00 committed by GitHub
parent c9b59962d6
commit cc3b74b5e8
46 changed files with 146 additions and 134 deletions

View File

@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.10.15
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: '3.10.15'
- name: Install pre-commit hook
run: |
pip install pre-commit

View File

@ -1,5 +1,9 @@
name: pr_stage_test
env:
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true
on:
pull_request:
paths-ignore:

View File

@ -1,7 +1,11 @@
exclude: ^tests/data/
repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8
rev: 5.0.4
- repo: https://github.com/pre-commit/pre-commit
rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
- id: flake8
- repo: https://gitee.com/openmmlab/mirrors-isort
@ -13,7 +17,7 @@ repos:
hooks:
- id: yapf
- repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
rev: v4.3.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-yaml
@ -55,7 +59,7 @@ repos:
args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs
- repo: https://gitee.com/openmmlab/mirrors-mypy
rev: v0.812
rev: v1.2.0
hooks:
- id: mypy
exclude: |-
@ -63,3 +67,4 @@ repos:
^examples
| ^docs
)
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]

View File

@ -1,7 +1,11 @@
exclude: ^tests/data/
repos:
- repo: https://github.com/pre-commit/pre-commit
rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 7.1.1
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
@ -13,7 +17,7 @@ repos:
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-yaml
@ -34,12 +38,8 @@ repos:
- mdformat-openmmlab
- mdformat_frontmatter
- linkify-it-py
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter
rev: v1.3.1
rev: 06907d0
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
@ -55,7 +55,7 @@ repos:
args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
rev: v1.2.0
hooks:
- id: mypy
exclude: |-
@ -63,3 +63,4 @@ repos:
^examples
| ^docs
)
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
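
The mypy hook now pins v1.2.0 and declares stub packages under additional_dependencies, so the pre-commit environment can resolve types for setuptools, requests, and PyYAML imports. A minimal sketch of the failure mode this avoids, using a hypothetical module that is not part of this diff:

import requests  # without the types-requests stubs, mypy reports an error
                 # like: Library stubs not installed for "requests"


def fetch_status(url: str) -> int:
    # With the stubs installed (for example via the hook's
    # additional_dependencies), mypy knows .status_code is an int,
    # so this return annotation type-checks.
    return requests.get(url, timeout=10).status_code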

View File

@ -499,7 +499,7 @@ class BaseStrategy(metaclass=ABCMeta):
'"type" and "constructor" are not in '
f'optimizer, but got {name}={optim}')
optim_wrappers[name] = optim
return OptimWrapperDict(**optim_wrappers)
return OptimWrapperDict(**optim_wrappers) # type: ignore
else:
raise TypeError('optimizer wrapper should be an OptimWrapper '
f'object or dict, but got {optim_wrapper}')

View File

@ -361,7 +361,7 @@ class ColossalAIStrategy(BaseStrategy):
map_location: Union[str, Callable] = 'default',
callback: Optional[Callable] = None,
) -> dict:
"""override this method since colossalai resume optimizer from filename
"""Override this method since colossalai resume optimizer from filename
directly."""
self.logger.info(f'Resume checkpoint from {filename}')

View File

@ -53,7 +53,7 @@ class DDPStrategy(SingleDeviceStrategy):
init_dist(launcher, backend, **kwargs)
def convert_model(self, model: nn.Module) -> nn.Module:
"""convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
"""Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
(SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.
Args:

View File

@ -393,7 +393,7 @@ class Config:
def __init__(
self,
cfg_dict: dict = None,
cfg_dict: Optional[dict] = None,
cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None,
@ -1227,7 +1227,8 @@ class Config:
if base_code is not None:
base_code = ast.Expression( # type: ignore
body=base_code.value) # type: ignore
base_files = eval(compile(base_code, '', mode='eval'))
base_files = eval(compile(base_code, '',
mode='eval')) # type: ignore
else:
base_files = []
elif file_format in ('.yml', '.yaml', '.json'):
@ -1288,7 +1289,7 @@ class Config:
def _merge_a_into_b(a: dict,
b: dict,
allow_list_keys: bool = False) -> dict:
"""merge dict ``a`` into dict ``b`` (non-inplace).
"""Merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications.
@ -1358,22 +1359,22 @@ class Config:
@property
def filename(self) -> str:
"""get file name of config."""
"""Get file name of config."""
return self._filename
@property
def text(self) -> str:
"""get config text."""
"""Get config text."""
return self._text
@property
def env_variables(self) -> dict:
"""get used environment variables."""
"""Get used environment variables."""
return self._env_variables
@property
def pretty_text(self) -> str:
"""get formatted python config text."""
"""Get formatted python config text."""
indent = 4
@ -1727,17 +1728,17 @@ class Config:
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options can
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""Argparse action to split an argument into KEY=VALUE form on the first =
and append to a dictionary.
List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3',
or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested
brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""
@staticmethod
def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
"""parse int/float/bool value in the string."""
"""Parse int/float/bool value in the string."""
try:
return int(val)
except ValueError:
@ -1822,7 +1823,7 @@ class DictAction(Action):
parser: ArgumentParser,
namespace: Namespace,
values: Union[str, Sequence[Any], None],
option_string: str = None):
option_string: str = None): # type: ignore
"""Parse Variables in string and add them into argparser.
Args:
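
Several signature changes in this file (and in the runner, misc, and progressbar diffs further down) replace a bare "arg: dict = None" with Optional[dict] = None. Since mypy 0.990 the implicit-Optional behaviour is off by default, so the newer v1.2.0 pin rejects a None default on a non-optional annotation. A minimal sketch of the rule, with a hypothetical function rather than mmengine's real signatures:

from typing import Optional


# Hypothetical example. Under mypy's default no_implicit_optional setting,
# "def load_cfg(path: str = None)" is reported as an incompatible default
# because None does not match str. Spelling the Optional out keeps the same
# runtime behaviour and satisfies the checker:
def load_cfg(path: Optional[str] = None) -> dict:
    return {} if path is None else {'filename': path}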

View File

@ -563,7 +563,7 @@ def cast_data_device(
Tensor or list or dict: ``data`` was casted to ``device``.
"""
if out is not None:
if type(data) != type(out):
if type(data) is not type(out):
raise TypeError(
'out should be the same type with data, but got data is '
f'{type(data)} and out is {type(data)}')
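
The comparison here moves from != to is not. With flake8 bumped to 7.1.1 in the pre-commit config above, pycodestyle's E721 check discourages comparing types with == or !=, preferring an identity check for exact types or isinstance for subclass-aware checks. A short illustrative sketch with hypothetical values:

data = {'inputs': [1, 2, 3]}
out = [0.0]

# Exact-type check: passes lint and cannot be fooled by a metaclass
# overriding __eq__ on the type objects.
if type(data) is not type(out):
    print('out should be the same type as data')

# Subclass-aware alternative, when dict subclasses are acceptable too.
if not isinstance(out, dict):
    print('out is not a dict')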

View File

@ -175,11 +175,11 @@ class DumpResults(BaseMetric):
self.out_file_path = out_file_path
def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
"""transfer tensors in predictions to CPU."""
"""Transfer tensors in predictions to CPU."""
self.results.extend(_to_cpu(predictions))
def compute_metrics(self, results: list) -> dict:
"""dump the prediction results to a pickle file."""
"""Dump the prediction results to a pickle file."""
dump(results, self.out_file_path)
print_log(
f'Results has been saved to {self.out_file_path}.',
@ -188,7 +188,7 @@ class DumpResults(BaseMetric):
def _to_cpu(data: Any) -> Any:
"""transfer all tensors and BaseDataElement to cpu."""
"""Transfer all tensors and BaseDataElement to cpu."""
if isinstance(data, (Tensor, BaseDataElement)):
return data.to('cpu')
elif isinstance(data, list):

View File

@ -233,7 +233,7 @@ class ProfilerHook(Hook):
self._export_chrome_trace(runner)
def after_train_iter(self, runner, batch_idx, data_batch, outputs):
"""profiler will call `step` method if it is not closed."""
"""Profiler will call `step` method if it is not closed."""
if not self._closed:
self.profiler.step()
if runner.iter == self.profile_times - 1 and not self.by_epoch:

View File

@ -58,7 +58,7 @@ class HistoryBuffer:
self._statistics_methods.setdefault('mean', HistoryBuffer.mean)
def update(self, log_val: Union[int, float], count: int = 1) -> None:
"""update the log history.
"""Update the log history.
If the length of the buffer exceeds ``self._max_length``, the oldest
element will be removed from the buffer.

View File

@ -253,17 +253,17 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
dict or list: Data in the same format as the model input.
"""
data = self.cast_data(data) # type: ignore
_batch_inputs = data['inputs']
_batch_inputs = data['inputs'] # type: ignore
# Process data with `pseudo_collate`.
if is_seq_of(_batch_inputs, torch.Tensor):
batch_inputs = []
for _batch_input in _batch_inputs:
# channel transform
if self._channel_conversion:
_batch_input = _batch_input[[2, 1, 0], ...]
_batch_input = _batch_input[[2, 1, 0], ...] # type: ignore
# Convert to float after channel conversion to ensure
# efficiency
_batch_input = _batch_input.float()
_batch_input = _batch_input.float() # type: ignore
# Normalization.
if self._enable_normalize:
if self.mean.shape[0] == 3:
@ -302,7 +302,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
else:
raise TypeError('Output of `cast_data` should be a dict of '
'list/tuple with inputs and data_samples, '
f'but got {type(data)}: {data}')
data['inputs'] = batch_inputs
data.setdefault('data_samples', None)
return data
f'but got {type(data)}: {data}') # type: ignore
data['inputs'] = batch_inputs # type: ignore
data.setdefault('data_samples', None) # type: ignore
return data # type: ignore
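
The "# type: ignore" comments added in this hunk silence errors that appear once mypy can no longer narrow the return value of cast_data to a dict. A minimal sketch of the pattern, with hypothetical types rather than the real mmengine signatures; mypy also accepts the error-code form shown below for a more targeted suppression:

from typing import Union


def cast_data(data: dict) -> Union[dict, list, tuple]:
    # Hypothetical helper that may hand back several container types.
    return data


def preprocess(raw: dict) -> dict:
    data = cast_data(raw)
    # mypy infers data as "Union[dict, list, tuple]", so indexing with a
    # string key is rejected. A bare "# type: ignore" (as in the diff) or a
    # targeted "# type: ignore[index]" suppresses just this line:
    inputs = data['inputs']  # type: ignore[index]
    return {'inputs': inputs}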

View File

@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0):
def bias_init_with_prob(prior_prob):
"""initialize conv/fc bias value according to a given probability value."""
"""Initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init
@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor,
std: float = 1.,
a: float = -2.,
b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
r"""Fills the input Tensor with values drawn from a truncated normal
distribution. The values are effectively drawn from the normal distribution
:math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
:math:`[a, b]` redrawn until they are within the bounds. The method used
for generating the random values works best when :math:`a \leq \text{mean}
\leq b`.
Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

View File

@ -127,7 +127,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
auto_wrap_policy: Union[str, Callable, None] = None,
backward_prefetch: Union[str, BackwardPrefetch, None] = None,
mixed_precision: Union[dict, MixedPrecision, None] = None,
param_init_fn: Union[str, Callable[[nn.Module], None]] = None,
param_init_fn: Union[str, Callable[
[nn.Module], None]] = None, # type: ignore # noqa: E501
use_orig_params: bool = True,
**kwargs,
):
@ -362,7 +363,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
optim: torch.optim.Optimizer,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""copied from pytorch 2.0.1 which has fixed some bugs."""
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
model)
return FullyShardedDataParallel._optim_state_dict_impl(
@ -384,7 +385,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> StateDictSettings:
"""copied from pytorch 2.0.1 which has fixed some bugs."""
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
import torch.distributed.fsdp._traversal_utils as traversal_utils
_state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullStateDictConfig,

View File

@ -123,8 +123,7 @@ class ApexOptimWrapper(OptimWrapper):
self._inner_count += 1
def state_dict(self) -> dict:
"""Get the state dictionary of :attr:`optimizer` and
:attr:`apex_amp`.
"""Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`.
Based on the state dictionary of the optimizer, the returned state
dictionary will add a key named "apex_amp".

View File

@ -131,7 +131,7 @@ class DefaultOptimWrapperConstructor:
self._validate_cfg()
def _validate_cfg(self) -> None:
"""verify the correctness of the config."""
"""Verify the correctness of the config."""
if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}')
@ -155,7 +155,7 @@ class DefaultOptimWrapperConstructor:
raise ValueError('base_wd should not be None')
def _is_in(self, param_group: dict, param_group_list: list) -> bool:
"""check whether the `param_group` is in the`param_group_list`"""
"""Check whether the `param_group` is in the`param_group_list`"""
assert is_list_of(param_group_list, dict)
param = set(param_group['params'])
param_set = set()

View File

@ -161,8 +161,7 @@ class OptimWrapperDict(OptimWrapper):
self.optim_wrappers[name].load_state_dict(_state_dict)
def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
"""A generator to get the name and corresponding
:obj:`OptimWrapper`"""
"""A generator to get the name and corresponding :obj:`OptimWrapper`"""
yield from self.optim_wrappers.items()
def values(self) -> Iterator[OptimWrapper]:

View File

@ -223,13 +223,13 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):
@PARAM_SCHEDULERS.register_module()
class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
r"""Sets the learning rate of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
rate from an initial learning rate to some maximum learning rate and then
from that maximum learning rate to some minimum learning rate much lower
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
r"""Sets the learning rate of each parameter group according to the 1cycle
learning rate policy. The 1cycle policy anneals the learning rate from an
initial learning rate to some maximum learning rate and then from that
maximum learning rate to some minimum learning rate much lower than the
initial learning rate. This policy was initially described in the paper
`Super-Convergence: Very Fast Training of Neural Networks Using Large
Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every
batch. `step` should be called after a batch has been used for training.

View File

@ -565,9 +565,9 @@ class ExponentialParamScheduler(_ParamScheduler):
@PARAM_SCHEDULERS.register_module()
class CosineAnnealingParamScheduler(_ParamScheduler):
r"""Set the parameter value of each parameter group using a cosine
annealing schedule, where :math:`\eta_{max}` is set to the initial value
and :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
r"""Set the parameter value of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial value and
:math:`T_{cur}` is the number of epochs since the last restart in SGDR:
.. math::
\begin{aligned}
@ -617,7 +617,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
"""
""" # noqa: E501
def __init__(self,
optimizer: Union[Optimizer, BaseOptimWrapper],
@ -890,13 +890,13 @@ class PolyParamScheduler(_ParamScheduler):
@PARAM_SCHEDULERS.register_module()
class OneCycleParamScheduler(_ParamScheduler):
r"""Sets the parameters of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
rate from an initial learning rate to some maximum learning rate and then
from that maximum learning rate to some minimum learning rate much lower
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
r"""Sets the parameters of each parameter group according to the 1cycle
learning rate policy. The 1cycle policy anneals the learning rate from an
initial learning rate to some maximum learning rate and then from that
maximum learning rate to some minimum learning rate much lower than the
initial learning rate. This policy was initially described in the paper
`Super-Convergence: Very Fast Training of Neural Networks Using Large
Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every
batch. `step` should be called after a batch has been used for training.

View File

@ -81,7 +81,7 @@ class DefaultScope(ManagerMixin):
@classmethod
@contextmanager
def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
"""overwrite the current default scope with `scope_name`"""
"""Overwrite the current default scope with `scope_name`"""
if scope_name is None:
yield
else:

View File

@ -332,7 +332,7 @@ class Registry:
return root
def import_from_location(self) -> None:
"""import modules from the pre-defined locations in self._location."""
"""Import modules from the pre-defined locations in self._location."""
if not self._imported:
# Avoid circular import
from ..logging import print_log

View File

@ -109,7 +109,7 @@ def init_default_scope(scope: str) -> None:
if current_scope.scope_name != scope: # type: ignore
print_log(
'The current default scope ' # type: ignore
f'"{current_scope.scope_name}" is not "{scope}", '
f'"{current_scope.scope_name}" is not "{scope}", ' # type: ignore
'`init_default_scope` will force set the current'
f'default scope to "{scope}".',
logger='current',

View File

@ -540,7 +540,7 @@ class FlexibleRunner:
@property
def hooks(self):
"""list[:obj:`Hook`]: A list of registered hooks."""
"""List[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
@property
@ -1117,7 +1117,7 @@ class FlexibleRunner:
return '\n'.join(stage_hook_infos)
def load_or_resume(self):
"""load or resume checkpoint."""
"""Load or resume checkpoint."""
if self._has_loaded:
return None
@ -1539,7 +1539,7 @@ class FlexibleRunner:
file_client_args: Optional[dict] = None,
save_optimizer: bool = True,
save_param_scheduler: bool = True,
meta: dict = None,
meta: Optional[dict] = None,
by_epoch: bool = True,
backend_args: Optional[dict] = None,
):

View File

@ -309,7 +309,7 @@ class CheckpointLoader:
@classmethod
def load_checkpoint(cls, filename, map_location=None, logger='current'):
"""load checkpoint through URL scheme path.
"""Load checkpoint through URL scheme path.
Args:
filename (str): checkpoint file name with given prefix
@ -332,7 +332,7 @@ class CheckpointLoader:
@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location):
"""load checkpoint by local file path.
"""Load checkpoint by local file path.
Args:
filename (str): local checkpoint file path
@ -353,7 +353,7 @@ def load_from_http(filename,
map_location=None,
model_dir=None,
progress=os.isatty(0)):
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
"""Load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0.
Args:
@ -386,7 +386,7 @@ def load_from_http(filename,
@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
"""load checkpoint through the file path prefixed with pavi. In distributed
"""Load checkpoint through the file path prefixed with pavi. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
@ -419,7 +419,7 @@ def load_from_pavi(filename, map_location=None):
@CheckpointLoader.register_scheme(
prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://'])
def load_from_ceph(filename, map_location=None, backend='petrel'):
"""load checkpoint through the file path prefixed with s3. In distributed
"""Load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
@ -441,7 +441,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None):
"""load checkpoint through the file path prefixed with modelzoo or
"""Load checkpoint through the file path prefixed with modelzoo or
torchvision.
Args:
@ -467,7 +467,7 @@ def load_from_torchvision(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None):
"""load checkpoint through the file path prefixed with open-mmlab or
"""Load checkpoint through the file path prefixed with open-mmlab or
openmmlab.
Args:
@ -510,7 +510,7 @@ def load_from_openmmlab(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None):
"""load checkpoint through the file path prefixed with mmcls.
"""Load checkpoint through the file path prefixed with mmcls.
Args:
filename (str): checkpoint file path with mmcls prefix

View File

@ -579,7 +579,7 @@ class Runner:
@property
def hooks(self):
"""list[:obj:`Hook`]: A list of registered hooks."""
"""List[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
@property
@ -720,7 +720,7 @@ class Runner:
def build_logger(self,
log_level: Union[int, str] = 'INFO',
log_file: str = None,
log_file: Optional[str] = None,
**kwargs) -> MMLogger:
"""Build a global asscessable MMLogger.
@ -1677,7 +1677,7 @@ class Runner:
return '\n'.join(stage_hook_infos)
def load_or_resume(self) -> None:
"""load or resume checkpoint."""
"""Load or resume checkpoint."""
if self._has_loaded:
return None

View File

@ -387,7 +387,7 @@ class BaseDataElement:
return dict(self.metainfo_items())
def __setattr__(self, name: str, value: Any):
"""setattr is only used to set data."""
"""Setattr is only used to set data."""
if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name):
super().__setattr__(name, value)

View File

@ -135,7 +135,7 @@ class InstanceData(BaseDataElement):
"""
def __setattr__(self, name: str, value: Sized):
"""setattr is only used to set data.
"""Setattr is only used to set data.
The value must have the attribute of `__len__` and have the same length
of `InstanceData`.

View File

@ -57,6 +57,7 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
check_hash=False,
file_name=None):
r"""Loads the Torch serialized object at the given URL.
If downloaded file is a zip file, it will be automatically decompressed
If the object is already present in `model_dir`, it's deserialized and
returned.

View File

@ -67,7 +67,7 @@ class TimeCounter:
instance.log_interval = log_interval
instance.warmup_interval = warmup_interval
instance.with_sync = with_sync
instance.with_sync = with_sync # type: ignore
instance.tag = tag
instance.logger = logger
@ -127,7 +127,7 @@ class TimeCounter:
self.print_time(elapsed)
def print_time(self, elapsed: Union[int, float]) -> None:
"""print times per count."""
"""Print times per count."""
if self.__count >= self.warmup_interval:
self.__pure_inf_time += elapsed

View File

@ -131,7 +131,7 @@ def tuple_cast(inputs, dst_type):
def is_seq_of(seq: Any,
expected_type: Union[Type, tuple],
seq_type: Type = None) -> bool:
seq_type: Optional[Type] = None) -> bool:
"""Check whether it is a sequence of some type.
Args:

View File

@ -69,11 +69,11 @@ def get_installed_path(package: str) -> str:
else:
raise e
possible_path = osp.join(pkg.location, package)
possible_path = osp.join(pkg.location, package) # type: ignore
if osp.exists(possible_path):
return possible_path
else:
return osp.join(pkg.location, package2module(package))
return osp.join(pkg.location, package2module(package)) # type: ignore
def package2module(package: str):

View File

@ -3,7 +3,7 @@ import sys
from collections.abc import Iterable
from multiprocessing import Pool
from shutil import get_terminal_size
from typing import Callable, Sequence
from typing import Callable, Optional, Sequence
from .timer import Timer
@ -54,7 +54,7 @@ class ProgressBar:
self.timer = Timer()
def update(self, num_tasks: int = 1):
"""update progressbar.
"""Update progressbar.
Args:
num_tasks (int): Update step size.
@ -142,8 +142,8 @@ def init_pool(process_num, initializer=None, initargs=None):
def track_parallel_progress(func: Callable,
tasks: Sequence,
nproc: int,
initializer: Callable = None,
initargs: tuple = None,
initializer: Optional[Callable] = None,
initargs: Optional[tuple] = None,
bar_width: int = 50,
chunksize: int = 1,
skip_first: bool = False,

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing import Pool
from typing import Callable, Iterable, Sized
from typing import Callable, Iterable, Optional, Sized
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
TaskProgressColumn, TextColumn, TimeRemainingColumn)
@ -47,7 +47,7 @@ def _tasks_with_index(tasks):
def track_progress_rich(func: Callable,
tasks: Iterable = tuple(),
task_num: int = None,
task_num: Optional[int] = None,
nproc: int = 1,
chunksize: int = 1,
description: str = 'Processing',

View File

@ -161,7 +161,7 @@ class BaseVisBackend(metaclass=ABCMeta):
pass
def close(self) -> None:
"""close an opened object."""
"""Close an opened object."""
pass
@ -314,7 +314,7 @@ class LocalVisBackend(BaseVisBackend):
def _dump(self, value_dict: dict, file_path: str,
file_format: str) -> None:
"""dump dict to file.
"""Dump dict to file.
Args:
value_dict (dict) : The dict data to saved.
@ -505,7 +505,7 @@ class WandbVisBackend(BaseVisBackend):
self._wandb.log(scalar_dict, commit=self._commit)
def close(self) -> None:
"""close an opened wandb object."""
"""Close an opened wandb object."""
if hasattr(self, '_wandb'):
self._wandb.join()
@ -629,7 +629,7 @@ class TensorboardVisBackend(BaseVisBackend):
self.add_scalar(key, value, step)
def close(self):
"""close an opened tensorboard object."""
"""Close an opened tensorboard object."""
if hasattr(self, '_tensorboard'):
self._tensorboard.close()
@ -1135,7 +1135,7 @@ class NeptuneVisBackend(BaseVisBackend):
self._neptune[k].append(v, step=step)
def close(self) -> None:
"""close an opened object."""
"""Close an opened object."""
if hasattr(self, '_neptune'):
self._neptune.stop()
@ -1282,7 +1282,7 @@ class DVCLiveVisBackend(BaseVisBackend):
self.add_scalar(key, value, step, **kwargs)
def close(self) -> None:
"""close an opened dvclive object."""
"""Close an opened dvclive object."""
if not hasattr(self, '_dvclive'):
return

View File

@ -356,7 +356,7 @@ class Visualizer(ManagerMixin):
@master_only
def get_backend(self, name) -> 'BaseVisBackend':
"""get vis backend by name.
"""Get vis backend by name.
Args:
name (str): The name of vis backend
@ -1145,7 +1145,7 @@ class Visualizer(ManagerMixin):
pass
def close(self) -> None:
"""close an opened object."""
"""Close an opened object."""
for vis_backend in self._vis_backends.values():
vis_backend.close()

View File

@ -843,8 +843,8 @@ class TestConfig:
assert cfg_dict['item4'] == 'test'
assert '_delete_' not in cfg_dict['item1']
assert type(cfg_dict['item1']) == ConfigDict
assert type(cfg_dict['item2']) == ConfigDict
assert type(cfg_dict['item1']) is ConfigDict
assert type(cfg_dict['item2']) is ConfigDict
def _merge_intermediate_variable(self):

View File

@ -300,8 +300,8 @@ except ImportError:
get_inputs.append(filepath)
with build_temporary_directory() as tmp_dir, \
patch.object(backend, 'put', side_effect=put),\
patch.object(backend, 'get', side_effect=get),\
patch.object(backend, 'put', side_effect=put), \
patch.object(backend, 'get', side_effect=get), \
patch.object(backend, 'exists', return_value=False):
tmp_dir = tmp_dir.replace('\\', '/')
dst = f'{tmp_dir}/dir'
@ -351,7 +351,7 @@ except ImportError:
with build_temporary_directory() as tmp_dir, \
patch.object(backend, 'copyfile_from_local',
side_effect=copyfile_from_local),\
side_effect=copyfile_from_local), \
patch.object(backend, 'exists', return_value=False):
backend.copytree_from_local(tmp_dir, self.petrel_dir)
@ -427,7 +427,7 @@ except ImportError:
def remove(filepath):
inputs.append(filepath)
with build_temporary_directory() as tmp_dir,\
with build_temporary_directory() as tmp_dir, \
patch.object(backend, 'remove', side_effect=remove):
backend.rmtree(tmp_dir)

View File

@ -13,7 +13,8 @@ from mmengine.testing import assert_allclose
class TestAveragedModel(TestCase):
"""Test the AveragedModel class.
Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
Some test cases are referenced from
https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
""" # noqa: E501
def _test_swa_model(self, net_device, avg_device):

View File

@ -102,7 +102,7 @@ class TestLRScheduler(TestCase):
rtol=0)
def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer."""
"""Warns if scheduler is used before optimizer."""
def call_sch_before_optim():
scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3)

View File

@ -120,7 +120,7 @@ class TestMomentumScheduler(TestCase):
rtol=0)
def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer."""
"""Warns if scheduler is used before optimizer."""
def call_sch_before_optim():
scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3)

View File

@ -127,7 +127,7 @@ class TestParameterScheduler(TestCase):
rtol=0)
def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer."""
"""Warns if scheduler is used before optimizer."""
def call_sch_before_optim():
scheduler = StepParamScheduler(

View File

@ -251,7 +251,7 @@ def test_load_checkpoint_metadata():
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
*args, **kwargs):
"""load checkpoints."""
"""Load checkpoints."""
# Names of some parameters in has been changed.
version = local_metadata.get('version', None)

View File

@ -2226,7 +2226,7 @@ class TestRunner(TestCase):
@HOOKS.register_module(force=True)
class TestWarmupHook(Hook):
"""test custom train loop."""
"""Test custom train loop."""
def before_warmup_iter(self, runner, data_batch=None):
before_warmup_iter_results.append('before')

View File

@ -64,7 +64,7 @@ class TestBaseDataElement(TestCase):
return metainfo, data
def is_equal(self, x, y):
assert type(x) == type(y)
assert type(x) is type(y)
if isinstance(
x, (int, float, str, list, tuple, dict, set, BaseDataElement)):
return x == y
@ -141,7 +141,7 @@ class TestBaseDataElement(TestCase):
# test new() with no arguments
new_instances = instances.new()
assert type(new_instances) == type(instances)
assert type(new_instances) is type(instances)
# After deepcopy, the address of new data'element will be same as
# origin, but when change new data' element will not effect the origin
# element and will have new address
@ -154,7 +154,7 @@ class TestBaseDataElement(TestCase):
# test new() with arguments
metainfo, data = self.setup_data()
new_instances = instances.new(metainfo=metainfo, **data)
assert type(new_instances) == type(instances)
assert type(new_instances) is type(instances)
assert id(new_instances.gt_instances) != id(instances.gt_instances)
_, new_data = self.setup_data()
new_instances.set_data(new_data)
@ -168,7 +168,7 @@ class TestBaseDataElement(TestCase):
metainfo, data = self.setup_data()
instances = BaseDataElement(metainfo=metainfo, **data)
new_instances = instances.clone()
assert type(new_instances) == type(instances)
assert type(new_instances) is type(instances)
def test_set_metainfo(self):
metainfo, _ = self.setup_data()

View File

@ -45,7 +45,7 @@ class MockVisBackend:
self._add_scalars = True
def close(self) -> None:
"""close an opened object."""
"""Close an opened object."""
self._close = True