diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b2686b7ca..0083dc27e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -46,6 +46,15 @@ repos:
     hooks:
       - id: check-copyright
         args: ["mmcv", "tests", "--excludes", "mmcv/ops"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v0.812
+    hooks:
+      - id: mypy
+        exclude: |-
+          (?x)(
+          ^test
+          | ^docs
+          )
   # - repo: local
   #   hooks:
   #     - id: clang-format
diff --git a/mmcv/cnn/bricks/plugin.py b/mmcv/cnn/bricks/plugin.py
index 009f7529b..6aa13f439 100644
--- a/mmcv/cnn/bricks/plugin.py
+++ b/mmcv/cnn/bricks/plugin.py
@@ -5,9 +5,9 @@ import platform
 from .registry import PLUGIN_LAYERS
 
 if platform.system() == 'Windows':
-    import regex as re
+    import regex as re  # type: ignore
 else:
-    import re
+    import re  # type: ignore
 
 
 def infer_abbr(class_type):
diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py
index e7fd7cdfa..aed74f054 100644
--- a/mmcv/fileio/file_client.py
+++ b/mmcv/fileio/file_client.py
@@ -8,7 +8,7 @@ import warnings
 from abc import ABCMeta, abstractmethod
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Iterable, Iterator, Optional, Tuple, Union
+from typing import Any, Generator, Iterator, Optional, Tuple, Union
 from urllib.request import urlopen
 
 import mmcv
@@ -298,7 +298,10 @@ class PetrelBackend(BaseStorageBackend):
         return '/'.join(formatted_paths)
 
     @contextmanager
-    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+    def get_local_path(
+            self,
+            filepath: Union[str,
+                            Path]) -> Generator[Union[str, Path], None, None]:
         """Download a file from ``filepath`` and return a temporary path.
 
         ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
@@ -646,7 +649,9 @@ class HardDiskBackend(BaseStorageBackend):
 
     @contextmanager
     def get_local_path(
-            self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
+            self,
+            filepath: Union[str,
+                            Path]) -> Generator[Union[str, Path], None, None]:
         """Only for unified API and do nothing."""
         yield filepath
 
@@ -715,7 +720,8 @@ class HTTPBackend(BaseStorageBackend):
         return value_buf.decode(encoding)
 
     @contextmanager
-    def get_local_path(self, filepath: str) -> Iterable[str]:
+    def get_local_path(
+            self, filepath: str) -> Generator[Union[str, Path], None, None]:
         """Download a file from ``filepath``.
 
         ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
@@ -789,15 +795,17 @@ class FileClient:
     # backend appears in the collection, the singleton pattern is disabled for
     # that backend, because if the singleton pattern is used, then the object
     # returned will be the backend before overwriting
-    _overridden_backends = set()
-    _prefix_to_backends = {
+    _overridden_backends: set = set()
+    _prefix_to_backends: dict = {
         's3': PetrelBackend,
         'http': HTTPBackend,
         'https': HTTPBackend,
     }
-    _overridden_prefixes = set()
+    _overridden_prefixes: set = set()
 
-    _instances = {}
+    _instances: dict = {}
+
+    client: Any
 
     def __new__(cls, backend=None, prefix=None, **kwargs):
         if backend is None and prefix is None:
@@ -1107,7 +1115,10 @@ class FileClient:
         return self.client.join_path(filepath, *filepaths)
 
     @contextmanager
-    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+    def get_local_path(
+            self,
+            filepath: Union[str,
+                            Path]) -> Generator[Union[str, Path], None, None]:
         """Download data from ``filepath`` and write the data to local path.
 
         ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
diff --git a/mmcv/fileio/handlers/yaml_handler.py b/mmcv/fileio/handlers/yaml_handler.py
index 60911e7e6..1c1b07794 100644
--- a/mmcv/fileio/handlers/yaml_handler.py
+++ b/mmcv/fileio/handlers/yaml_handler.py
@@ -5,7 +5,7 @@ try:
     from yaml import CDumper as Dumper
     from yaml import CLoader as Loader
 except ImportError:
-    from yaml import Loader, Dumper
+    from yaml import Loader, Dumper  # type: ignore
 
 from .base import BaseFileHandler  # isort:skip
 
diff --git a/mmcv/onnx/onnx_utils/symbolic_helper.py b/mmcv/onnx/onnx_utils/symbolic_helper.py
index a9a31eb4a..ded275672 100644
--- a/mmcv/onnx/onnx_utils/symbolic_helper.py
+++ b/mmcv/onnx/onnx_utils/symbolic_helper.py
@@ -328,4 +328,4 @@ cast_pytorch_to_onnx = {
 # Global set to store the list of quantized operators in the network.
 # This is currently only used in the conversion of quantized ops from PT
 # -> C2 via ONNX.
-_quantized_ops = set()
+_quantized_ops: set = set()
diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py
index 835ee725a..31e395edc 100644
--- a/mmcv/runner/checkpoint.py
+++ b/mmcv/runner/checkpoint.py
@@ -200,7 +200,7 @@ def _process_mmcls_checkpoint(checkpoint):
 class CheckpointLoader:
     """A general checkpoint loader to manage all schemes."""
 
-    _schemes = {}
+    _schemes: dict = {}
 
     @classmethod
     def _register_scheme(cls, prefixes, loader, force=False):
diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py
index 12c885183..b34cc9ddc 100644
--- a/mmcv/runner/hooks/optimizer.py
+++ b/mmcv/runner/hooks/optimizer.py
@@ -342,7 +342,7 @@ if (TORCH_VERSION != 'parrots'
 else:
 
     @HOOKS.register_module()
-    class Fp16OptimizerHook(OptimizerHook):
+    class Fp16OptimizerHook(OptimizerHook):  # type: ignore
         """FP16 optimizer hook (mmcv's implementation).
 
         The steps of fp16 optimizer is as follows.
@@ -484,8 +484,8 @@ else:
                 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
 
     @HOOKS.register_module()
-    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
-                                              Fp16OptimizerHook):
+    class GradientCumulativeFp16OptimizerHook(  # type: ignore
+            GradientCumulativeOptimizerHook, Fp16OptimizerHook):
         """Fp16 optimizer Hook (using mmcv implementation) implements multi-
         iters gradient cumulating."""
 
diff --git a/mmcv/utils/config.py b/mmcv/utils/config.py
index 8efbc24e2..541ddf01a 100644
--- a/mmcv/utils/config.py
+++ b/mmcv/utils/config.py
@@ -22,9 +22,9 @@ from .misc import import_modules_from_strings
 from .path import check_file_exist
 
 if platform.system() == 'Windows':
-    import regex as re
+    import regex as re  # type: ignore
 else:
-    import re
+    import re  # type: ignore
 
 BASE_KEY = '_base_'
 DELETE_KEY = '_delete_'
diff --git a/mmcv/utils/hub.py b/mmcv/utils/hub.py
index 12fbff2ee..a9cbbc95b 100644
--- a/mmcv/utils/hub.py
+++ b/mmcv/utils/hub.py
@@ -128,4 +128,4 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
                     'loaded in torch<1.5.')
             raise error
 else:
-    from torch.utils.model_zoo import load_url  # noqa: F401
+    from torch.utils.model_zoo import load_url  # type: ignore # noqa: F401
diff --git a/mmcv/utils/logging.py b/mmcv/utils/logging.py
index c4c7025f0..5a90aac8b 100644
--- a/mmcv/utils/logging.py
+++ b/mmcv/utils/logging.py
@@ -3,7 +3,7 @@ import logging
 
 import torch.distributed as dist
 
-logger_initialized = {}
+logger_initialized: dict = {}
 
 
 def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
diff --git a/mmcv/utils/parrots_wrapper.py b/mmcv/utils/parrots_wrapper.py
index 7e657b561..cf2c7e5ce 100644
--- a/mmcv/utils/parrots_wrapper.py
+++ b/mmcv/utils/parrots_wrapper.py
@@ -103,7 +103,7 @@ _BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
 _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
 
 
-class SyncBatchNorm(SyncBatchNorm_):
+class SyncBatchNorm(SyncBatchNorm_):  # type: ignore
 
     def _check_input_dim(self, input):
         if TORCH_VERSION == 'parrots':
diff --git a/mmcv/utils/version_utils.py b/mmcv/utils/version_utils.py
index 963c45a2e..77c41f608 100644
--- a/mmcv/utils/version_utils.py
+++ b/mmcv/utils/version_utils.py
@@ -41,7 +41,7 @@ def digit_version(version_str: str, length: int = 4):
             release.extend([val, 0])
 
     elif version.is_postrelease:
-        release.extend([1, version.post])
+        release.extend([1, version.post])  # type: ignore
     else:
         release.extend([0, 0])
     return tuple(release)
diff --git a/mmcv/version.py b/mmcv/version.py
index a97ffc2dd..000c2619a 100644
--- a/mmcv/version.py
+++ b/mmcv/version.py
@@ -22,9 +22,9 @@ def parse_version_info(version_str: str, length: int = 4) -> tuple:
     if len(release) < length:
         release = release + [0] * (length - len(release))
     if version.is_prerelease:
-        release.extend(list(version.pre))
+        release.extend(list(version.pre))  # type: ignore
     elif version.is_postrelease:
-        release.extend(list(version.post))
+        release.extend(list(version.post))  # type: ignore
     else:
         release.extend([0, 0])
     return tuple(release)
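
Note on the recurring ``get_local_path`` change: a function decorated with ``contextlib.contextmanager`` is a generator function, and the typeshed stub bundled with mypy expects the wrapped callable to return an ``Iterator``, so annotating the return type as ``Generator[Union[str, Path], None, None]`` (instead of ``Iterable[...]``) is what lets the decorator type-check. A minimal sketch of the pattern, using a hypothetical ``LocalBackend`` class that is not part of this patch:

# Minimal sketch of the Generator annotation pattern applied above.
# `LocalBackend` is a hypothetical stand-in, not an mmcv class.
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Union


class LocalBackend:

    @contextmanager
    def get_local_path(
            self, filepath: Union[str, Path]
    ) -> Generator[Union[str, Path], None, None]:
        # Nothing to download for a local file; hand the path straight back.
        yield filepath


with LocalBackend().get_local_path('example.txt') as path:
    print(path)  # prints 'example.txt'

``Iterator[Union[str, Path]]`` would also satisfy the stub; ``Generator[..., None, None]`` simply spells out that nothing is sent into or returned from the generator.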