[Refactor] Refactor the import rule (#459)

* [Refactor] Refactor the import rule

* minor refinement

* add a comment
This commit is contained in:
Zaida Zhou 2022-08-23 18:58:36 +08:00 committed by GitHub
parent a9ad09bded
commit 486d8cda56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 32 additions and 34 deletions

View File

@ -2,7 +2,7 @@
from typing import Iterator, List, Optional, Sequence, Union
from mmengine.data import BaseDataElement
from ..registry.root import EVALUATOR, METRICS
from mmengine.registry import EVALUATOR, METRICS
from .metric import BaseMetric

View File

@ -10,8 +10,7 @@ from pathlib import Path
from typing import Any, Generator, Iterator, Optional, Tuple, Union
from urllib.request import urlopen
import mmengine
from mmengine.utils import has_method, is_filepath
from mmengine.utils import has_method, is_filepath, mkdir_or_exist
class BaseStorageBackend(metaclass=ABCMeta):
@ -523,7 +522,7 @@ class HardDiskBackend(BaseStorageBackend):
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
"""
mmengine.mkdir_or_exist(osp.dirname(filepath))
mkdir_or_exist(osp.dirname(filepath))
with open(filepath, 'wb') as f:
f.write(obj)
@ -543,7 +542,7 @@ class HardDiskBackend(BaseStorageBackend):
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
"""
mmengine.mkdir_or_exist(osp.dirname(filepath))
mkdir_or_exist(osp.dirname(filepath))
with open(filepath, 'w', encoding=encoding) as f:
f.write(obj)

View File

@ -2,7 +2,7 @@
from io import BytesIO, StringIO
from pathlib import Path
from ..utils import is_list_of, is_str
from mmengine.utils import is_list_of, is_str
from .file_client import FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler

View File

@ -9,8 +9,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Union
from mmengine.dist import master_only
from mmengine.fileio import FileClient
from mmengine.registry import HOOKS
from mmengine.utils import is_seq_of
from mmengine.utils.misc import is_list_of
from mmengine.utils import is_list_of, is_seq_of
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence
from ..registry import HOOKS
from ..utils import get_git_hash
from ..version import __version__
from mmengine.registry import HOOKS
from mmengine.utils import get_git_hash
from mmengine.version import __version__
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import dist
from mmengine.dist import all_reduce_params, is_distributed
from mmengine.registry import HOOKS
from .hook import Hook
@ -12,7 +12,7 @@ class SyncBuffersHook(Hook):
priority = 'NORMAL'
def __init__(self) -> None:
self.distributed = dist.is_distributed()
self.distributed = is_distributed()
def after_train_epoch(self, runner) -> None:
"""All-reduce model buffers at the end of each epoch.
@ -21,4 +21,4 @@ class SyncBuffersHook(Hook):
runner (Runner): The runner of the training process.
"""
if self.distributed:
dist.all_reduce_params(runner.model.buffers(), op='mean')
all_reduce_params(runner.model.buffers(), op='mean')

View File

@ -4,9 +4,7 @@ import datetime
from collections import OrderedDict
from typing import List, Optional, Tuple
import torch
from mmengine.device import get_max_cuda_memory
from mmengine.device import get_max_cuda_memory, is_cuda_available
from mmengine.registry import LOG_PROCESSORS
@ -173,7 +171,7 @@ class LogProcessor:
log_tag.pop('data_time')
# If cuda is available, the max memory occupied should be calculated.
if torch.cuda.is_available():
if is_cuda_available():
log_str += f'memory: {self._get_max_memory(runner)} '
# Loop left keys to fill `log_str`.
if mode in ('train', 'val'):

View File

@ -7,7 +7,7 @@ from typing import Optional, Union
from termcolor import colored
from mmengine import dist
from mmengine.dist import get_rank
from mmengine.utils import ManagerMixin
from mmengine.utils.manager import _accquire_lock, _release_lock
@ -152,7 +152,7 @@ class MMLogger(Logger, ManagerMixin):
Logger.__init__(self, logger_name)
ManagerMixin.__init__(self, name)
# Get rank in DDP mode.
rank = dist.get_rank()
rank = get_rank()
# Config stream_handler. If `rank != 0`. stream_handler can only
# export ERROR logs.

View File

@ -7,7 +7,7 @@ import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel
from mmengine.data import BaseDataElement
from mmengine.device.utils import get_device
from mmengine.device import get_device
from mmengine.optim import OptimWrapperDict
from mmengine.registry import MODEL_WRAPPERS
from .distributed import MMDistributedDataParallel

View File

@ -5,9 +5,9 @@ from typing import Optional
import torch
from ..device import get_device
from ..logging import print_log
from ..utils import TORCH_VERSION, digit_version
from mmengine.device import get_device
from mmengine.logging import print_log
from mmengine.utils import TORCH_VERSION, digit_version
@contextmanager

View File

@ -33,10 +33,10 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel,
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS,
PARAM_SCHEDULERS, RUNNERS, VISUALIZERS,
DefaultScope, count_registered_modules)
from mmengine.registry.root import LOG_PROCESSORS
LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS,
OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
VISUALIZERS, DefaultScope,
count_registered_modules)
from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
get_git_hash, is_seq_of, revert_sync_batchnorm,
set_multi_processing)

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
from mmengine.utils.misc import is_list_of
from mmengine.utils import is_list_of
def calc_dynamic_intervals(

View File

@ -8,14 +8,16 @@ from unittest.mock import patch
import pytest
from mmengine import MMLogger, print_log
from mmengine.logging import MMLogger, print_log
class TestLogger:
stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
@patch('mmengine.dist.get_rank', lambda: 0)
# Since `get_rank` has been imported in logger.py, it needs to mock
# `logger.get_rank`
@patch('mmengine.logging.logger.get_rank', lambda: 0)
def test_init_rank0(self, tmp_path):
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
assert logger.name == 'mmengine'
@ -45,7 +47,7 @@ class TestLogger:
assert logger.instance_name == 'rank0.pkg3'
logging.shutdown()
@patch('mmengine.dist.get_rank', lambda: 1)
@patch('mmengine.logging.logger.get_rank', lambda: 1)
def test_init_rank1(self, tmp_path):
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
tmp_file = tmp_path / 'tmp_file.log'

View File

@ -6,7 +6,7 @@ import numpy as np
import pytest
import torch
from mmengine import HistoryBuffer, MessageHub
from mmengine.logging import HistoryBuffer, MessageHub
class NoDeepCopy:

View File

@ -11,8 +11,8 @@ from torch.cuda.amp import GradScaler
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD, Adam, Optimizer
from mmengine import MessageHub, MMLogger
from mmengine.dist import all_gather
from mmengine.logging import MessageHub, MMLogger
from mmengine.optim import AmpOptimWrapper, OptimWrapper
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase