mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* Add build_runner_with_tta and PrepareTTAHook * rename hook file * support build tta runner with runner type * add unit test * Add build_runner_with_tta to index.rst * minor refine * Add runner test cast * Fix unit test * fix unit test * tmp save * pop None if key does not exist * Fix is_model_wrapper and force register class in test_runner * [Fix] Fix is_model_wrapper * destroy group after ut * register module in testcase * pass through unit test * fix as comment * remove breakpoint * remove mmengine/testing/runner_test_cast.py * minor refine * minor refine * minor refine * set default data preprocessor for model * minor refine * minor refine Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * Fix unit test * replace with in ImgDataPreprocessor * Fix as comment * add inference tutorial in advanced tutorial * update index.rst * add tta example * refine tta tutorial * Add english tutorial * add note for build_runner_with_tta * Fix as comment * add examples * remove chinese comment * Update docs/en/advanced_tutorials/test_time_augmentation.md Co-authored-by: RangiLyu <lyuchqi@gmail.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: RangiLyu <lyuchqi@gmail.com>
39 lines
2.2 KiB
Python
39 lines
2.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from mmengine.utils.dl_utils import TORCH_VERSION
|
|
from mmengine.utils.version_utils import digit_version
|
|
from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
|
|
MomentumAnnealingEMA, StochasticWeightAverage)
|
|
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
|
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
|
|
from .test_time_aug import BaseTTAModel
|
|
from .utils import (convert_sync_batchnorm, detect_anomalous_params,
|
|
merge_dict, revert_sync_batchnorm, stack_batch)
|
|
from .weight_init import (BaseInit, Caffe2XavierInit, ConstantInit,
|
|
KaimingInit, NormalInit, PretrainedInit,
|
|
TruncNormalInit, UniformInit, XavierInit,
|
|
bias_init_with_prob, caffe2_xavier_init,
|
|
constant_init, initialize, kaiming_init, normal_init,
|
|
trunc_normal_init, uniform_init, update_init_info,
|
|
xavier_init)
|
|
from .wrappers import (MMDistributedDataParallel,
|
|
MMSeparateDistributedDataParallel, is_model_wrapper)
|
|
|
|
# Names re-exported by this package. The order is significant only for
# readability; ``from <package> import *`` binds exactly these names.
__all__ = [
    'MMDistributedDataParallel',
    'is_model_wrapper',
    'BaseAveragedModel',
    'StochasticWeightAverage',
    'ExponentialMovingAverage',
    'MomentumAnnealingEMA',
    'BaseModel',
    'BaseDataPreprocessor',
    'ImgDataPreprocessor',
    'MMSeparateDistributedDataParallel',
    'BaseModule',
    'stack_batch',
    'merge_dict',
    'detect_anomalous_params',
    'ModuleList',
    'ModuleDict',
    'Sequential',
    'revert_sync_batchnorm',
    'update_init_info',
    'constant_init',
    'xavier_init',
    'normal_init',
    'trunc_normal_init',
    'uniform_init',
    'kaiming_init',
    'caffe2_xavier_init',
    'bias_init_with_prob',
    'BaseInit',
    'ConstantInit',
    'XavierInit',
    'NormalInit',
    'TruncNormalInit',
    'UniformInit',
    'KaimingInit',
    'Caffe2XavierInit',
    'PretrainedInit',
    'initialize',
    'convert_sync_batchnorm',
    'BaseTTAModel',
]

# ``MMFullyShardedDataParallel`` is imported and exported only when the
# installed torch version is at least 1.11.0; on older versions the name is
# neither bound in this package nor listed in ``__all__``.
if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
    from .wrappers import MMFullyShardedDataParallel  # noqa:F401
    # In-place list extension: mutates the same ``__all__`` list object.
    __all__ += ['MMFullyShardedDataParallel']
|