[Refactor] Refactor the import rule (#459)

* [Refactor] Refactor the import rule

* minor refinement

* add a comment
This commit is contained in:
Zaida Zhou 2022-08-23 18:58:36 +08:00 committed by GitHub
parent a9ad09bded
commit 486d8cda56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 32 additions and 34 deletions

View File

@ -2,7 +2,7 @@
from typing import Iterator, List, Optional, Sequence, Union from typing import Iterator, List, Optional, Sequence, Union
from mmengine.data import BaseDataElement from mmengine.data import BaseDataElement
from ..registry.root import EVALUATOR, METRICS from mmengine.registry import EVALUATOR, METRICS
from .metric import BaseMetric from .metric import BaseMetric

View File

@ -10,8 +10,7 @@ from pathlib import Path
from typing import Any, Generator, Iterator, Optional, Tuple, Union from typing import Any, Generator, Iterator, Optional, Tuple, Union
from urllib.request import urlopen from urllib.request import urlopen
import mmengine from mmengine.utils import has_method, is_filepath, mkdir_or_exist
from mmengine.utils import has_method, is_filepath
class BaseStorageBackend(metaclass=ABCMeta): class BaseStorageBackend(metaclass=ABCMeta):
@ -523,7 +522,7 @@ class HardDiskBackend(BaseStorageBackend):
obj (bytes): Data to be written. obj (bytes): Data to be written.
filepath (str or Path): Path to write data. filepath (str or Path): Path to write data.
""" """
mmengine.mkdir_or_exist(osp.dirname(filepath)) mkdir_or_exist(osp.dirname(filepath))
with open(filepath, 'wb') as f: with open(filepath, 'wb') as f:
f.write(obj) f.write(obj)
@ -543,7 +542,7 @@ class HardDiskBackend(BaseStorageBackend):
encoding (str): The encoding format used to open the ``filepath``. encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'. Default: 'utf-8'.
""" """
mmengine.mkdir_or_exist(osp.dirname(filepath)) mkdir_or_exist(osp.dirname(filepath))
with open(filepath, 'w', encoding=encoding) as f: with open(filepath, 'w', encoding=encoding) as f:
f.write(obj) f.write(obj)

View File

@ -2,7 +2,7 @@
from io import BytesIO, StringIO from io import BytesIO, StringIO
from pathlib import Path from pathlib import Path
from ..utils import is_list_of, is_str from mmengine.utils import is_list_of, is_str
from .file_client import FileClient from .file_client import FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler

View File

@ -9,8 +9,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Union
from mmengine.dist import master_only from mmengine.dist import master_only
from mmengine.fileio import FileClient from mmengine.fileio import FileClient
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
from mmengine.utils import is_seq_of from mmengine.utils import is_list_of, is_seq_of
from mmengine.utils.misc import is_list_of
from .hook import Hook from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]] DATA_BATCH = Optional[Sequence[dict]]

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence from typing import Dict, Optional, Sequence
from ..registry import HOOKS from mmengine.registry import HOOKS
from ..utils import get_git_hash from mmengine.utils import get_git_hash
from ..version import __version__ from mmengine.version import __version__
from .hook import Hook from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]] DATA_BATCH = Optional[Sequence[dict]]

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmengine import dist from mmengine.dist import all_reduce_params, is_distributed
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
from .hook import Hook from .hook import Hook
@ -12,7 +12,7 @@ class SyncBuffersHook(Hook):
priority = 'NORMAL' priority = 'NORMAL'
def __init__(self) -> None: def __init__(self) -> None:
self.distributed = dist.is_distributed() self.distributed = is_distributed()
def after_train_epoch(self, runner) -> None: def after_train_epoch(self, runner) -> None:
"""All-reduce model buffers at the end of each epoch. """All-reduce model buffers at the end of each epoch.
@ -21,4 +21,4 @@ class SyncBuffersHook(Hook):
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
""" """
if self.distributed: if self.distributed:
dist.all_reduce_params(runner.model.buffers(), op='mean') all_reduce_params(runner.model.buffers(), op='mean')

View File

@ -4,9 +4,7 @@ import datetime
from collections import OrderedDict from collections import OrderedDict
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch from mmengine.device import get_max_cuda_memory, is_cuda_available
from mmengine.device import get_max_cuda_memory
from mmengine.registry import LOG_PROCESSORS from mmengine.registry import LOG_PROCESSORS
@ -173,7 +171,7 @@ class LogProcessor:
log_tag.pop('data_time') log_tag.pop('data_time')
# If cuda is available, the max memory occupied should be calculated. # If cuda is available, the max memory occupied should be calculated.
if torch.cuda.is_available(): if is_cuda_available():
log_str += f'memory: {self._get_max_memory(runner)} ' log_str += f'memory: {self._get_max_memory(runner)} '
# Loop left keys to fill `log_str`. # Loop left keys to fill `log_str`.
if mode in ('train', 'val'): if mode in ('train', 'val'):

View File

@ -7,7 +7,7 @@ from typing import Optional, Union
from termcolor import colored from termcolor import colored
from mmengine import dist from mmengine.dist import get_rank
from mmengine.utils import ManagerMixin from mmengine.utils import ManagerMixin
from mmengine.utils.manager import _accquire_lock, _release_lock from mmengine.utils.manager import _accquire_lock, _release_lock
@ -152,7 +152,7 @@ class MMLogger(Logger, ManagerMixin):
Logger.__init__(self, logger_name) Logger.__init__(self, logger_name)
ManagerMixin.__init__(self, name) ManagerMixin.__init__(self, name)
# Get rank in DDP mode. # Get rank in DDP mode.
rank = dist.get_rank() rank = get_rank()
# Config stream_handler. If `rank != 0`. stream_handler can only # Config stream_handler. If `rank != 0`. stream_handler can only
# export ERROR logs. # export ERROR logs.

View File

@ -7,7 +7,7 @@ import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel from torch.nn.parallel.distributed import DistributedDataParallel
from mmengine.data import BaseDataElement from mmengine.data import BaseDataElement
from mmengine.device.utils import get_device from mmengine.device import get_device
from mmengine.optim import OptimWrapperDict from mmengine.optim import OptimWrapperDict
from mmengine.registry import MODEL_WRAPPERS from mmengine.registry import MODEL_WRAPPERS
from .distributed import MMDistributedDataParallel from .distributed import MMDistributedDataParallel

View File

@ -5,9 +5,9 @@ from typing import Optional
import torch import torch
from ..device import get_device from mmengine.device import get_device
from ..logging import print_log from mmengine.logging import print_log
from ..utils import TORCH_VERSION, digit_version from mmengine.utils import TORCH_VERSION, digit_version
@contextmanager @contextmanager

View File

@ -33,10 +33,10 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel,
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler, from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
build_optim_wrapper) build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS,
PARAM_SCHEDULERS, RUNNERS, VISUALIZERS, OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
DefaultScope, count_registered_modules) VISUALIZERS, DefaultScope,
from mmengine.registry.root import LOG_PROCESSORS count_registered_modules)
from mmengine.utils import (TORCH_VERSION, collect_env, digit_version, from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
get_git_hash, is_seq_of, revert_sync_batchnorm, get_git_hash, is_seq_of, revert_sync_batchnorm,
set_multi_processing) set_multi_processing)

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from mmengine.utils.misc import is_list_of from mmengine.utils import is_list_of
def calc_dynamic_intervals( def calc_dynamic_intervals(

View File

@ -8,14 +8,16 @@ from unittest.mock import patch
import pytest import pytest
from mmengine import MMLogger, print_log from mmengine.logging import MMLogger, print_log
class TestLogger: class TestLogger:
stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}' stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}' file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
@patch('mmengine.dist.get_rank', lambda: 0) # Since `get_rank` has been imported in logger.py, it needs to mock
# `logger.get_rank`
@patch('mmengine.logging.logger.get_rank', lambda: 0)
def test_init_rank0(self, tmp_path): def test_init_rank0(self, tmp_path):
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO') logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
assert logger.name == 'mmengine' assert logger.name == 'mmengine'
@ -45,7 +47,7 @@ class TestLogger:
assert logger.instance_name == 'rank0.pkg3' assert logger.instance_name == 'rank0.pkg3'
logging.shutdown() logging.shutdown()
@patch('mmengine.dist.get_rank', lambda: 1) @patch('mmengine.logging.logger.get_rank', lambda: 1)
def test_init_rank1(self, tmp_path): def test_init_rank1(self, tmp_path):
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`. # If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
tmp_file = tmp_path / 'tmp_file.log' tmp_file = tmp_path / 'tmp_file.log'

View File

@ -6,7 +6,7 @@ import numpy as np
import pytest import pytest
import torch import torch
from mmengine import HistoryBuffer, MessageHub from mmengine.logging import HistoryBuffer, MessageHub
class NoDeepCopy: class NoDeepCopy:

View File

@ -11,8 +11,8 @@ from torch.cuda.amp import GradScaler
from torch.nn.parallel.distributed import DistributedDataParallel from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD, Adam, Optimizer from torch.optim import SGD, Adam, Optimizer
from mmengine import MessageHub, MMLogger
from mmengine.dist import all_gather from mmengine.dist import all_gather
from mmengine.logging import MessageHub, MMLogger
from mmengine.optim import AmpOptimWrapper, OptimWrapper from mmengine.optim import AmpOptimWrapper, OptimWrapper
from mmengine.testing import assert_allclose from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase from mmengine.testing._internal import MultiProcessTestCase