mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor the import rule (#459)
* [Refactor] Refactor the import rule * minor refinement * add a comment
This commit is contained in:
parent
a9ad09bded
commit
486d8cda56
@ -2,7 +2,7 @@
|
|||||||
from typing import Iterator, List, Optional, Sequence, Union
|
from typing import Iterator, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from mmengine.data import BaseDataElement
|
from mmengine.data import BaseDataElement
|
||||||
from ..registry.root import EVALUATOR, METRICS
|
from mmengine.registry import EVALUATOR, METRICS
|
||||||
from .metric import BaseMetric
|
from .metric import BaseMetric
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,8 +10,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Generator, Iterator, Optional, Tuple, Union
|
from typing import Any, Generator, Iterator, Optional, Tuple, Union
|
||||||
from urllib.request import urlopen
|
from urllib.request import urlopen
|
||||||
|
|
||||||
import mmengine
|
from mmengine.utils import has_method, is_filepath, mkdir_or_exist
|
||||||
from mmengine.utils import has_method, is_filepath
|
|
||||||
|
|
||||||
|
|
||||||
class BaseStorageBackend(metaclass=ABCMeta):
|
class BaseStorageBackend(metaclass=ABCMeta):
|
||||||
@ -523,7 +522,7 @@ class HardDiskBackend(BaseStorageBackend):
|
|||||||
obj (bytes): Data to be written.
|
obj (bytes): Data to be written.
|
||||||
filepath (str or Path): Path to write data.
|
filepath (str or Path): Path to write data.
|
||||||
"""
|
"""
|
||||||
mmengine.mkdir_or_exist(osp.dirname(filepath))
|
mkdir_or_exist(osp.dirname(filepath))
|
||||||
with open(filepath, 'wb') as f:
|
with open(filepath, 'wb') as f:
|
||||||
f.write(obj)
|
f.write(obj)
|
||||||
|
|
||||||
@ -543,7 +542,7 @@ class HardDiskBackend(BaseStorageBackend):
|
|||||||
encoding (str): The encoding format used to open the ``filepath``.
|
encoding (str): The encoding format used to open the ``filepath``.
|
||||||
Default: 'utf-8'.
|
Default: 'utf-8'.
|
||||||
"""
|
"""
|
||||||
mmengine.mkdir_or_exist(osp.dirname(filepath))
|
mkdir_or_exist(osp.dirname(filepath))
|
||||||
with open(filepath, 'w', encoding=encoding) as f:
|
with open(filepath, 'w', encoding=encoding) as f:
|
||||||
f.write(obj)
|
f.write(obj)
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from io import BytesIO, StringIO
|
from io import BytesIO, StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from ..utils import is_list_of, is_str
|
from mmengine.utils import is_list_of, is_str
|
||||||
from .file_client import FileClient
|
from .file_client import FileClient
|
||||||
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
|
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
|
||||||
|
|
||||||
|
@ -9,8 +9,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Union
|
|||||||
from mmengine.dist import master_only
|
from mmengine.dist import master_only
|
||||||
from mmengine.fileio import FileClient
|
from mmengine.fileio import FileClient
|
||||||
from mmengine.registry import HOOKS
|
from mmengine.registry import HOOKS
|
||||||
from mmengine.utils import is_seq_of
|
from mmengine.utils import is_list_of, is_seq_of
|
||||||
from mmengine.utils.misc import is_list_of
|
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
|
|
||||||
DATA_BATCH = Optional[Sequence[dict]]
|
DATA_BATCH = Optional[Sequence[dict]]
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Dict, Optional, Sequence
|
from typing import Dict, Optional, Sequence
|
||||||
|
|
||||||
from ..registry import HOOKS
|
from mmengine.registry import HOOKS
|
||||||
from ..utils import get_git_hash
|
from mmengine.utils import get_git_hash
|
||||||
from ..version import __version__
|
from mmengine.version import __version__
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
|
|
||||||
DATA_BATCH = Optional[Sequence[dict]]
|
DATA_BATCH = Optional[Sequence[dict]]
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from mmengine import dist
|
from mmengine.dist import all_reduce_params, is_distributed
|
||||||
from mmengine.registry import HOOKS
|
from mmengine.registry import HOOKS
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ class SyncBuffersHook(Hook):
|
|||||||
priority = 'NORMAL'
|
priority = 'NORMAL'
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.distributed = dist.is_distributed()
|
self.distributed = is_distributed()
|
||||||
|
|
||||||
def after_train_epoch(self, runner) -> None:
|
def after_train_epoch(self, runner) -> None:
|
||||||
"""All-reduce model buffers at the end of each epoch.
|
"""All-reduce model buffers at the end of each epoch.
|
||||||
@ -21,4 +21,4 @@ class SyncBuffersHook(Hook):
|
|||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
"""
|
"""
|
||||||
if self.distributed:
|
if self.distributed:
|
||||||
dist.all_reduce_params(runner.model.buffers(), op='mean')
|
all_reduce_params(runner.model.buffers(), op='mean')
|
||||||
|
@ -4,9 +4,7 @@ import datetime
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
from mmengine.device import get_max_cuda_memory, is_cuda_available
|
||||||
|
|
||||||
from mmengine.device import get_max_cuda_memory
|
|
||||||
from mmengine.registry import LOG_PROCESSORS
|
from mmengine.registry import LOG_PROCESSORS
|
||||||
|
|
||||||
|
|
||||||
@ -173,7 +171,7 @@ class LogProcessor:
|
|||||||
log_tag.pop('data_time')
|
log_tag.pop('data_time')
|
||||||
|
|
||||||
# If cuda is available, the max memory occupied should be calculated.
|
# If cuda is available, the max memory occupied should be calculated.
|
||||||
if torch.cuda.is_available():
|
if is_cuda_available():
|
||||||
log_str += f'memory: {self._get_max_memory(runner)} '
|
log_str += f'memory: {self._get_max_memory(runner)} '
|
||||||
# Loop left keys to fill `log_str`.
|
# Loop left keys to fill `log_str`.
|
||||||
if mode in ('train', 'val'):
|
if mode in ('train', 'val'):
|
||||||
|
@ -7,7 +7,7 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from mmengine import dist
|
from mmengine.dist import get_rank
|
||||||
from mmengine.utils import ManagerMixin
|
from mmengine.utils import ManagerMixin
|
||||||
from mmengine.utils.manager import _accquire_lock, _release_lock
|
from mmengine.utils.manager import _accquire_lock, _release_lock
|
||||||
|
|
||||||
@ -152,7 +152,7 @@ class MMLogger(Logger, ManagerMixin):
|
|||||||
Logger.__init__(self, logger_name)
|
Logger.__init__(self, logger_name)
|
||||||
ManagerMixin.__init__(self, name)
|
ManagerMixin.__init__(self, name)
|
||||||
# Get rank in DDP mode.
|
# Get rank in DDP mode.
|
||||||
rank = dist.get_rank()
|
rank = get_rank()
|
||||||
|
|
||||||
# Config stream_handler. If `rank != 0`. stream_handler can only
|
# Config stream_handler. If `rank != 0`. stream_handler can only
|
||||||
# export ERROR logs.
|
# export ERROR logs.
|
||||||
|
@ -7,7 +7,7 @@ import torch.nn as nn
|
|||||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||||
|
|
||||||
from mmengine.data import BaseDataElement
|
from mmengine.data import BaseDataElement
|
||||||
from mmengine.device.utils import get_device
|
from mmengine.device import get_device
|
||||||
from mmengine.optim import OptimWrapperDict
|
from mmengine.optim import OptimWrapperDict
|
||||||
from mmengine.registry import MODEL_WRAPPERS
|
from mmengine.registry import MODEL_WRAPPERS
|
||||||
from .distributed import MMDistributedDataParallel
|
from .distributed import MMDistributedDataParallel
|
||||||
|
@ -5,9 +5,9 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..device import get_device
|
from mmengine.device import get_device
|
||||||
from ..logging import print_log
|
from mmengine.logging import print_log
|
||||||
from ..utils import TORCH_VERSION, digit_version
|
from mmengine.utils import TORCH_VERSION, digit_version
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -33,10 +33,10 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel,
|
|||||||
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
|
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
|
||||||
build_optim_wrapper)
|
build_optim_wrapper)
|
||||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
|
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
|
||||||
LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS,
|
LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS,
|
||||||
PARAM_SCHEDULERS, RUNNERS, VISUALIZERS,
|
OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
|
||||||
DefaultScope, count_registered_modules)
|
VISUALIZERS, DefaultScope,
|
||||||
from mmengine.registry.root import LOG_PROCESSORS
|
count_registered_modules)
|
||||||
from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
|
from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
|
||||||
get_git_hash, is_seq_of, revert_sync_batchnorm,
|
get_git_hash, is_seq_of, revert_sync_batchnorm,
|
||||||
set_multi_processing)
|
set_multi_processing)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from mmengine.utils.misc import is_list_of
|
from mmengine.utils import is_list_of
|
||||||
|
|
||||||
|
|
||||||
def calc_dynamic_intervals(
|
def calc_dynamic_intervals(
|
||||||
|
@ -8,14 +8,16 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mmengine import MMLogger, print_log
|
from mmengine.logging import MMLogger, print_log
|
||||||
|
|
||||||
|
|
||||||
class TestLogger:
|
class TestLogger:
|
||||||
stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
|
stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
|
||||||
file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
|
file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
|
||||||
|
|
||||||
@patch('mmengine.dist.get_rank', lambda: 0)
|
# Since `get_rank` has been imported in logger.py, it needs to mock
|
||||||
|
# `logger.get_rank`
|
||||||
|
@patch('mmengine.logging.logger.get_rank', lambda: 0)
|
||||||
def test_init_rank0(self, tmp_path):
|
def test_init_rank0(self, tmp_path):
|
||||||
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
|
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
|
||||||
assert logger.name == 'mmengine'
|
assert logger.name == 'mmengine'
|
||||||
@ -45,7 +47,7 @@ class TestLogger:
|
|||||||
assert logger.instance_name == 'rank0.pkg3'
|
assert logger.instance_name == 'rank0.pkg3'
|
||||||
logging.shutdown()
|
logging.shutdown()
|
||||||
|
|
||||||
@patch('mmengine.dist.get_rank', lambda: 1)
|
@patch('mmengine.logging.logger.get_rank', lambda: 1)
|
||||||
def test_init_rank1(self, tmp_path):
|
def test_init_rank1(self, tmp_path):
|
||||||
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
|
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
|
||||||
tmp_file = tmp_path / 'tmp_file.log'
|
tmp_file = tmp_path / 'tmp_file.log'
|
||||||
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmengine import HistoryBuffer, MessageHub
|
from mmengine.logging import HistoryBuffer, MessageHub
|
||||||
|
|
||||||
|
|
||||||
class NoDeepCopy:
|
class NoDeepCopy:
|
||||||
|
@ -11,8 +11,8 @@ from torch.cuda.amp import GradScaler
|
|||||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||||
from torch.optim import SGD, Adam, Optimizer
|
from torch.optim import SGD, Adam, Optimizer
|
||||||
|
|
||||||
from mmengine import MessageHub, MMLogger
|
|
||||||
from mmengine.dist import all_gather
|
from mmengine.dist import all_gather
|
||||||
|
from mmengine.logging import MessageHub, MMLogger
|
||||||
from mmengine.optim import AmpOptimWrapper, OptimWrapper
|
from mmengine.optim import AmpOptimWrapper, OptimWrapper
|
||||||
from mmengine.testing import assert_allclose
|
from mmengine.testing import assert_allclose
|
||||||
from mmengine.testing._internal import MultiProcessTestCase
|
from mmengine.testing._internal import MultiProcessTestCase
|
||||||
|
Loading…
x
Reference in New Issue
Block a user