Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Refactor] Refactor the import rule (#459)
* [Refactor] Refactor the import rule
* minor refinement
* add a comment
This commit is contained in:
parent a9ad09bded
commit 486d8cda56
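In short, the rule this commit applies: modules inside mmengine import public names from the top-level subpackages (mmengine.registry, mmengine.utils, mmengine.dist, mmengine.device, mmengine.logging, ...) instead of using relative paths or reaching into private submodules such as ..registry.root or mmengine.utils.misc. A minimal before/after sketch; the imported names are taken from the hunks below, while the surrounding module is illustrative:

# Before: relative and private-submodule imports
# from ..registry import HOOKS
# from mmengine.utils.misc import is_list_of

# After: absolute imports from the public top-level subpackages
from mmengine.registry import HOOKS
from mmengine.utils import is_list_of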
@@ -2,7 +2,7 @@
 from typing import Iterator, List, Optional, Sequence, Union

 from mmengine.data import BaseDataElement
-from ..registry.root import EVALUATOR, METRICS
+from mmengine.registry import EVALUATOR, METRICS
 from .metric import BaseMetric
@@ -10,8 +10,7 @@ from pathlib import Path
 from typing import Any, Generator, Iterator, Optional, Tuple, Union
 from urllib.request import urlopen

-import mmengine
-from mmengine.utils import has_method, is_filepath
+from mmengine.utils import has_method, is_filepath, mkdir_or_exist


 class BaseStorageBackend(metaclass=ABCMeta):
@@ -523,7 +522,7 @@ class HardDiskBackend(BaseStorageBackend):
             obj (bytes): Data to be written.
             filepath (str or Path): Path to write data.
         """
-        mmengine.mkdir_or_exist(osp.dirname(filepath))
+        mkdir_or_exist(osp.dirname(filepath))
         with open(filepath, 'wb') as f:
             f.write(obj)
@@ -543,7 +542,7 @@ class HardDiskBackend(BaseStorageBackend):
             encoding (str): The encoding format used to open the ``filepath``.
                 Default: 'utf-8'.
         """
-        mmengine.mkdir_or_exist(osp.dirname(filepath))
+        mkdir_or_exist(osp.dirname(filepath))
         with open(filepath, 'w', encoding=encoding) as f:
             f.write(obj)
@@ -2,7 +2,7 @@
 from io import BytesIO, StringIO
 from pathlib import Path

-from ..utils import is_list_of, is_str
+from mmengine.utils import is_list_of, is_str
 from .file_client import FileClient
 from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
@@ -9,8 +9,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Union
 from mmengine.dist import master_only
 from mmengine.fileio import FileClient
 from mmengine.registry import HOOKS
-from mmengine.utils import is_seq_of
-from mmengine.utils.misc import is_list_of
+from mmengine.utils import is_list_of, is_seq_of
 from .hook import Hook

 DATA_BATCH = Optional[Sequence[dict]]
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Dict, Optional, Sequence

-from ..registry import HOOKS
-from ..utils import get_git_hash
-from ..version import __version__
+from mmengine.registry import HOOKS
+from mmengine.utils import get_git_hash
+from mmengine.version import __version__
 from .hook import Hook

 DATA_BATCH = Optional[Sequence[dict]]
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmengine import dist
+from mmengine.dist import all_reduce_params, is_distributed
 from mmengine.registry import HOOKS
 from .hook import Hook
@@ -12,7 +12,7 @@ class SyncBuffersHook(Hook):
     priority = 'NORMAL'

     def __init__(self) -> None:
-        self.distributed = dist.is_distributed()
+        self.distributed = is_distributed()

     def after_train_epoch(self, runner) -> None:
         """All-reduce model buffers at the end of each epoch.
@@ -21,4 +21,4 @@ class SyncBuffersHook(Hook):
             runner (Runner): The runner of the training process.
         """
         if self.distributed:
-            dist.all_reduce_params(runner.model.buffers(), op='mean')
+            all_reduce_params(runner.model.buffers(), op='mean')
@@ -4,9 +4,7 @@ import datetime
 from collections import OrderedDict
 from typing import List, Optional, Tuple

-import torch
-
-from mmengine.device import get_max_cuda_memory
+from mmengine.device import get_max_cuda_memory, is_cuda_available
 from mmengine.registry import LOG_PROCESSORS
@@ -173,7 +171,7 @@ class LogProcessor:
                 log_tag.pop('data_time')

             # If cuda is available, the max memory occupied should be calculated.
-            if torch.cuda.is_available():
+            if is_cuda_available():
                 log_str += f'memory: {self._get_max_memory(runner)} '
             # Loop left keys to fill `log_str`.
             if mode in ('train', 'val'):
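The hunk above swaps the direct torch.cuda.is_available() check for mmengine's device helper, which pairs with the get_max_cuda_memory import earlier in this file. A minimal usage sketch, assuming mmengine is installed; the print call is illustrative:

from mmengine.device import get_max_cuda_memory, is_cuda_available

if is_cuda_available():
    # get_max_cuda_memory() reports the peak allocated GPU memory in MB.
    print(f'memory: {get_max_cuda_memory()}')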
@@ -7,7 +7,7 @@ from typing import Optional, Union

 from termcolor import colored

-from mmengine import dist
+from mmengine.dist import get_rank
 from mmengine.utils import ManagerMixin
 from mmengine.utils.manager import _accquire_lock, _release_lock
@@ -152,7 +152,7 @@ class MMLogger(Logger, ManagerMixin):
         Logger.__init__(self, logger_name)
         ManagerMixin.__init__(self, name)
         # Get rank in DDP mode.
-        rank = dist.get_rank()
+        rank = get_rank()

         # Config stream_handler. If `rank != 0`. stream_handler can only
         # export ERROR logs.
@@ -7,7 +7,7 @@ import torch.nn as nn
 from torch.nn.parallel.distributed import DistributedDataParallel

 from mmengine.data import BaseDataElement
-from mmengine.device.utils import get_device
+from mmengine.device import get_device
 from mmengine.optim import OptimWrapperDict
 from mmengine.registry import MODEL_WRAPPERS
 from .distributed import MMDistributedDataParallel
@@ -5,9 +5,9 @@ from typing import Optional

 import torch

-from ..device import get_device
-from ..logging import print_log
-from ..utils import TORCH_VERSION, digit_version
+from mmengine.device import get_device
+from mmengine.logging import print_log
+from mmengine.utils import TORCH_VERSION, digit_version


 @contextmanager
@@ -33,10 +33,10 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel,
 from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
                             build_optim_wrapper)
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
-                               LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS,
-                               PARAM_SCHEDULERS, RUNNERS, VISUALIZERS,
-                               DefaultScope, count_registered_modules)
-from mmengine.registry.root import LOG_PROCESSORS
+                               LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS,
+                               OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
+                               VISUALIZERS, DefaultScope,
+                               count_registered_modules)
 from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
                             get_git_hash, is_seq_of, revert_sync_batchnorm,
                             set_multi_processing)
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import List, Optional, Tuple

-from mmengine.utils.misc import is_list_of
+from mmengine.utils import is_list_of


 def calc_dynamic_intervals(
@@ -8,14 +8,16 @@ from unittest.mock import patch

 import pytest

-from mmengine import MMLogger, print_log
+from mmengine.logging import MMLogger, print_log


 class TestLogger:
     stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
     file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'

-    @patch('mmengine.dist.get_rank', lambda: 0)
+    # Since `get_rank` has been imported in logger.py, it needs to mock
+    # `logger.get_rank`
+    @patch('mmengine.logging.logger.get_rank', lambda: 0)
     def test_init_rank0(self, tmp_path):
         logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
         assert logger.name == 'mmengine'
@@ -45,7 +47,7 @@ class TestLogger:
         assert logger.instance_name == 'rank0.pkg3'
         logging.shutdown()

-    @patch('mmengine.dist.get_rank', lambda: 1)
+    @patch('mmengine.logging.logger.get_rank', lambda: 1)
     def test_init_rank1(self, tmp_path):
         # If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
         tmp_file = tmp_path / 'tmp_file.log'
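The comment added in the first test hunk above reflects a general unittest.mock rule: patch a name where it is looked up, not where it is defined. Because logger.py now does `from mmengine.dist import get_rank`, the reference used at call time lives in mmengine.logging.logger and must be patched there. A minimal sketch, assuming mmengine (at a revision including this commit) is installed; the instance name is illustrative:

from unittest import mock

from mmengine.logging import MMLogger

# Patching mmengine.dist.get_rank alone would not affect the name already
# bound inside mmengine.logging.logger, so patch it at its point of use.
with mock.patch('mmengine.logging.logger.get_rank', lambda: 1):
    logger = MMLogger.get_instance('rank1_demo', log_level='INFO')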
@@ -6,7 +6,7 @@ import numpy as np
 import pytest
 import torch

-from mmengine import HistoryBuffer, MessageHub
+from mmengine.logging import HistoryBuffer, MessageHub


 class NoDeepCopy:
@@ -11,8 +11,8 @@ from torch.cuda.amp import GradScaler
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import SGD, Adam, Optimizer

-from mmengine import MessageHub, MMLogger
 from mmengine.dist import all_gather
+from mmengine.logging import MessageHub, MMLogger
 from mmengine.optim import AmpOptimWrapper, OptimWrapper
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase