[Fix] Fix type hint in file_client (#1942)

* [fix]:fix type hint in file_client and mmcv

* [fix]:fix type hint in tests files

* [fix]:fix type hint in tests files

* [fix]:fix pre-commit.yaml to igore test for mypy

* [fix]:fix pre-commit.yaml to igore test for mypy

* [fix]:fix precommit.yml

* [fix]:fix precommit.yml

* Update __init__.py

delete unused type-ignore comment
pull/1951/head
Alex Yang 2022-05-10 14:01:07 +08:00 committed by GitHub
parent fb9af9f36f
commit a848ecfdfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 45 additions and 25 deletions

View File

@ -46,6 +46,15 @@ repos:
hooks:
- id: check-copyright
args: ["mmcv", "tests", "--excludes", "mmcv/ops"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
hooks:
- id: mypy
exclude: |-
(?x)(
^test
| ^docs
)
# - repo: local
# hooks:
# - id: clang-format

View File

@ -5,9 +5,9 @@ import platform
from .registry import PLUGIN_LAYERS
if platform.system() == 'Windows':
import regex as re
import regex as re # type: ignore
else:
import re
import re # type: ignore
def infer_abbr(class_type):

View File

@ -8,7 +8,7 @@ import warnings
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable, Iterator, Optional, Tuple, Union
from typing import Any, Generator, Iterator, Optional, Tuple, Union
from urllib.request import urlopen
import mmcv
@ -298,7 +298,10 @@ class PetrelBackend(BaseStorageBackend):
return '/'.join(formatted_paths)
@contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
def get_local_path(
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath`` and return a temporary path.
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
@ -646,7 +649,9 @@ class HardDiskBackend(BaseStorageBackend):
@contextmanager
def get_local_path(
self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Only for unified API and do nothing."""
yield filepath
@ -715,7 +720,8 @@ class HTTPBackend(BaseStorageBackend):
return value_buf.decode(encoding)
@contextmanager
def get_local_path(self, filepath: str) -> Iterable[str]:
def get_local_path(
self, filepath: str) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath``.
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
@ -789,15 +795,17 @@ class FileClient:
# backend appears in the collection, the singleton pattern is disabled for
# that backend, because if the singleton pattern is used, then the object
# returned will be the backend before overwriting
_overridden_backends = set()
_prefix_to_backends = {
_overridden_backends: set = set()
_prefix_to_backends: dict = {
's3': PetrelBackend,
'http': HTTPBackend,
'https': HTTPBackend,
}
_overridden_prefixes = set()
_overridden_prefixes: set = set()
_instances = {}
_instances: dict = {}
client: Any
def __new__(cls, backend=None, prefix=None, **kwargs):
if backend is None and prefix is None:
@ -1107,7 +1115,10 @@ class FileClient:
return self.client.join_path(filepath, *filepaths)
@contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
def get_local_path(
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Download data from ``filepath`` and write the data to local path.
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It

View File

@ -5,7 +5,7 @@ try:
from yaml import CDumper as Dumper
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader, Dumper
from yaml import Loader, Dumper # type: ignore
from .base import BaseFileHandler # isort:skip

View File

@ -328,4 +328,4 @@ cast_pytorch_to_onnx = {
# Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT
# -> C2 via ONNX.
_quantized_ops = set()
_quantized_ops: set = set()

View File

@ -200,7 +200,7 @@ def _process_mmcls_checkpoint(checkpoint):
class CheckpointLoader:
"""A general checkpoint loader to manage all schemes."""
_schemes = {}
_schemes: dict = {}
@classmethod
def _register_scheme(cls, prefixes, loader, force=False):

View File

@ -342,7 +342,7 @@ if (TORCH_VERSION != 'parrots'
else:
@HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook):
class Fp16OptimizerHook(OptimizerHook): # type: ignore
"""FP16 optimizer hook (mmcv's implementation).
The steps of fp16 optimizer is as follows.
@ -484,8 +484,8 @@ else:
'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
@HOOKS.register_module()
class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
Fp16OptimizerHook):
class GradientCumulativeFp16OptimizerHook( # type: ignore
GradientCumulativeOptimizerHook, Fp16OptimizerHook):
"""Fp16 optimizer Hook (using mmcv implementation) implements multi-
iters gradient cumulating."""

View File

@ -22,9 +22,9 @@ from .misc import import_modules_from_strings
from .path import check_file_exist
if platform.system() == 'Windows':
import regex as re
import regex as re # type: ignore
else:
import re
import re # type: ignore
BASE_KEY = '_base_'
DELETE_KEY = '_delete_'

View File

@ -128,4 +128,4 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
'loaded in torch<1.5.')
raise error
else:
from torch.utils.model_zoo import load_url # noqa: F401
from torch.utils.model_zoo import load_url # type: ignore # noqa: F401

View File

@ -3,7 +3,7 @@ import logging
import torch.distributed as dist
logger_initialized = {}
logger_initialized: dict = {}
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):

View File

@ -103,7 +103,7 @@ _BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
class SyncBatchNorm(SyncBatchNorm_):
class SyncBatchNorm(SyncBatchNorm_): # type: ignore
def _check_input_dim(self, input):
if TORCH_VERSION == 'parrots':

View File

@ -41,7 +41,7 @@ def digit_version(version_str: str, length: int = 4):
release.extend([val, 0])
elif version.is_postrelease:
release.extend([1, version.post])
release.extend([1, version.post]) # type: ignore
else:
release.extend([0, 0])
return tuple(release)

View File

@ -22,9 +22,9 @@ def parse_version_info(version_str: str, length: int = 4) -> tuple:
if len(release) < length:
release = release + [0] * (length - len(release))
if version.is_prerelease:
release.extend(list(version.pre))
release.extend(list(version.pre)) # type: ignore
elif version.is_postrelease:
release.extend(list(version.post))
release.extend(list(version.post)) # type: ignore
else:
release.extend([0, 0])
return tuple(release)