mmengine/mmengine/model/__init__.py
Mashiro a9b6753fbe
Make TTAModel compatible with FSDP (#611)
* Add build_runner_with_tta and PrepareTTAHook

* rename hook file

* support build tta runner with runner type

* add unit test

* Add build_runner_with_tta to index.rst

* minor refine

* Add runner test cast

* Fix unit test

* fix unit test

* tmp save

* pop None if key does not exist

* Fix is_model_wrapper and force register class in test_runner

* [Fix] Fix is_model_wrapper

* destroy group after ut

* register module in testcase

* pass through unit test

* fix as comment

* remove breakpoint

* remove mmengine/testing/runner_test_cast.py

* minor refine

* minor refine

* minor refine

* set default data preprocessor for model

* minor refine

* minor refine

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix lint

* Fix unit test

* replace  with  in ImgDataPreprocessor

* Fix as comment

* add inference tutorial in advanced tutorial

* update index.rst

* add tta example

* refine tta tutorial

* Add english tutorial

* add note for build_runner_with_tta

* Fix as comment

* add examples

* remove chinese comment

* Update docs/en/advanced_tutorials/test_time_augmentation.md

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: RangiLyu <lyuchqi@gmail.com>
2022-12-27 21:39:43 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
                             MomentumAnnealingEMA, StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .test_time_aug import BaseTTAModel
from .utils import (convert_sync_batchnorm, detect_anomalous_params,
                    merge_dict, revert_sync_batchnorm, stack_batch)
from .weight_init import (BaseInit, Caffe2XavierInit, ConstantInit,
                          KaimingInit, NormalInit, PretrainedInit,
                          TruncNormalInit, UniformInit, XavierInit,
                          bias_init_with_prob, caffe2_xavier_init,
                          constant_init, initialize, kaiming_init, normal_init,
                          trunc_normal_init, uniform_init, update_init_info,
                          xavier_init)
from .wrappers import (MMDistributedDataParallel,
                       MMSeparateDistributedDataParallel, is_model_wrapper)

__all__ = [
    'MMDistributedDataParallel', 'is_model_wrapper', 'BaseAveragedModel',
    'StochasticWeightAverage', 'ExponentialMovingAverage',
    'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor',
    'ImgDataPreprocessor', 'MMSeparateDistributedDataParallel', 'BaseModule',
    'stack_batch', 'merge_dict', 'detect_anomalous_params', 'ModuleList',
    'ModuleDict', 'Sequential', 'revert_sync_batchnorm', 'update_init_info',
    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
    'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit',
    'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit',
    'Caffe2XavierInit', 'PretrainedInit', 'initialize',
    'convert_sync_batchnorm', 'BaseTTAModel'
]

if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
    from .wrappers import MMFullyShardedDataParallel  # noqa:F401
    __all__.append('MMFullyShardedDataParallel')
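
The final block gates the FSDP wrapper on the installed PyTorch version, so `MMFullyShardedDataParallel` is only re-exported when `torch >= 1.11.0`. Below is a minimal downstream sketch of how that gate can be consumed together with `is_model_wrapper`; the `unwrap` helper is hypothetical and not part of mmengine, only the exports listed in the file above are real.

# Minimal sketch (not part of mmengine): import the version-gated FSDP wrapper
# defensively, since it is only exported when PyTorch >= 1.11.0 is installed.
from mmengine.model import is_model_wrapper

try:
    from mmengine.model import MMFullyShardedDataParallel  # torch >= 1.11.0
except ImportError:
    MMFullyShardedDataParallel = None  # FSDP wrapper unavailable on older torch


def unwrap(model):
    """Hypothetical helper: return the inner module when ``model`` is a
    registered wrapper (DDP, separate DDP, or FSDP when available),
    otherwise return ``model`` unchanged."""
    return model.module if is_model_wrapper(model) else model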