Fix docstring format (#337)

parent 2fd6beb972
commit 6015fd35e5
@@ -27,7 +27,8 @@ class LogProcessor:
         custom_cfg (list[dict], optional): Contains multiple log config dict,
             in which key means the data source name of log and value means the
             statistic method and corresponding arguments used to count the
-            data source. Defaults to None
+            data source. Defaults to None.
+
             - If custom_cfg is None, all logs will be formatted via default
               methods, such as smoothing loss by default window_size. If
               custom_cfg is defined as a list of config dict, for example:
@@ -35,12 +36,12 @@ class LogProcessor:
               window_size='global')]. It means the log item ``loss`` will be
               counted as global mean and additionally logged as ``global_loss``
               (defined by ``log_name``). If ``log_name`` is not defined in
-              config dict, the original logged key will be overwritten.
+              config dict, the original logged key will be overwritten.
 
             - The original log item cannot be overwritten twice. Here is
               an error example:
               [dict(data_src=loss, method='mean', window_size='global'),
-              dict(data_src=loss, method='mean', window_size='epoch')].
+              dict(data_src=loss, method='mean', window_size='epoch')].
               Both log config dict in custom_cfg do not have ``log_name`` key,
               which means the loss item will be overwritten twice.
 
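
As a worked example of the option documented above, here is a minimal sketch of a ``custom_cfg`` list. It reuses the dict keys shown in the docstring excerpt (``data_src``, ``log_name``, ``method``, ``window_size``); treat the exact key names as illustrative, since they may differ across mmengine versions.

# Hypothetical custom_cfg, mirroring the keys used in the docstring above.
custom_cfg = [
    # Smooth the raw ``loss`` over the last 20 iterations; the original
    # logged key is kept because no ``log_name`` is given.
    dict(data_src='loss', method='mean', window_size=20),
    # Additionally log the global mean of ``loss`` under a new key, so the
    # original ``loss`` entry is not overwritten a second time.
    dict(data_src='loss', log_name='global_loss', method='mean',
         window_size='global'),
]
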
@@ -246,6 +246,7 @@ def print_log(msg,
         logger (Logger or str, optional): If the type of logger is
             ``logging.Logger``, we directly use logger to log messages.
             Some special loggers are:
+
             - "silent": No message will be printed.
             - "current": Use latest created logger to log message.
             - other str: Instance name of logger. The corresponding logger
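
The special logger values listed above are easiest to compare side by side. A short sketch, assuming ``print_log`` and ``MMLogger`` are imported from ``mmengine.logging`` as in mmengine's public API; the logger name ``'my_logger'`` is hypothetical.

from mmengine.logging import MMLogger, print_log

MMLogger.get_instance('my_logger')             # create a named logger first
print_log('plain message')                     # logger=None: printed directly
print_log('routed message', logger='current')  # latest created logger
print_log('hidden message', logger='silent')   # nothing is printed
print_log('named route', logger='my_logger')   # the instance created above
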
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmengine.utils.parrots_wrapper import TORCH_VERSION
 from mmengine.utils.version_utils import digit_version
-from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
-                             StochasticWeightAverage)
+from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
+                             MomentumAnnealingEMA, StochasticWeightAverage)
 from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
 from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
 from .utils import detect_anomalous_params, merge_dict, stack_batch
@@ -10,12 +10,12 @@ from .wrappers import (MMDistributedDataParallel,
                        MMSeparateDistributedDataParallel, is_model_wrapper)
 
 __all__ = [
-    'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage',
-    'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
-    'BaseDataPreprocessor', 'ImgDataPreprocessor',
-    'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
-    'merge_dict', 'detect_anomalous_params', 'ModuleList', 'ModuleDict',
-    'Sequential'
+    'MMDistributedDataParallel', 'is_model_wrapper', 'BaseAveragedModel',
+    'StochasticWeightAverage', 'ExponentialMovingAverage',
+    'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor',
+    'ImgDataPreprocessor', 'MMSeparateDistributedDataParallel', 'BaseModule',
+    'stack_batch', 'merge_dict', 'detect_anomalous_params', 'ModuleList',
+    'ModuleDict', 'Sequential'
 ]
 
 if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
@@ -17,25 +17,28 @@ class BaseAveragedModel(nn.Module):
     training neural networks. This class implements the averaging process
     for a model. All subclasses must implement the `avg_func` method.
     This class creates a copy of the provided module :attr:`model`
-    on the device :attr:`device` and allows computing running averages of the
+    on the :attr:`device` and allows computing running averages of the
     parameters of the :attr:`model`.
 
     The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.
 
     Different from the `AveragedModel` in PyTorch, we use in-place operation
     to improve the parameter updating speed, which is about 5 times faster
     than the non-in-place version.
 
     In mmengine, we provide two ways to use the model averaging:
+
     1. Use the model averaging module in hook:
-    We provide an EMAHook to apply the model averaging during training.
-    Add ``custom_hooks=[dict(type='EMAHook')]`` to the config or the runner.
-    The hook is implemented in mmengine/hooks/ema_hook.py
+       We provide an EMAHook to apply the model averaging during training.
+       Add ``custom_hooks=[dict(type='EMAHook')]`` to the config or the runner.
+       The hook is implemented in mmengine/hooks/ema_hook.py
+
     2. Use the model averaging module directly in the algorithm. Take the ema
        teacher in semi-supervise as an example:
-    >>> from mmengine.model import ExponentialMovingAverage
-    >>> student = ResNet(depth=50)
-    >>> # use ema model as teacher
-    >>> ema_teacher = ExponentialMovingAverage(student)
+       >>> from mmengine.model import ExponentialMovingAverage
+       >>> student = ResNet(depth=50)
+       >>> # use ema model as teacher
+       >>> ema_teacher = ExponentialMovingAverage(student)
+
     Args:
         model (nn.Module): The model to be averaged.
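
A runnable sketch of the second route above, extending the doctest with an update step. It assumes ``update_parameters`` mirrors the ``torch.optim.swa_utils.AveragedModel`` convention that the docstring cites, and swaps in a toy ``nn.Linear`` for ``ResNet(depth=50)``.

import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage

student = nn.Linear(4, 2)                    # toy stand-in for ResNet(depth=50)
ema_teacher = ExponentialMovingAverage(student)
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

for _ in range(10):
    loss = student(torch.randn(8, 4)).sum()  # placeholder objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Fold the student's new weights into the running average (assumed API).
    ema_teacher.update_parameters(student)
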
@@ -134,7 +137,7 @@ class StochasticWeightAverage(BaseAveragedModel):
 
 @MODELS.register_module()
 class ExponentialMovingAverage(BaseAveragedModel):
-    """Implements the exponential moving average (EMA) of the model.
+    r"""Implements the exponential moving average (EMA) of the model.
 
     All parameters are updated by the formula as below:
 
@@ -145,9 +148,10 @@ class ExponentialMovingAverage(BaseAveragedModel):
     Args:
         model (nn.Module): The model to be averaged.
         momentum (float): The momentum used for updating ema parameter.
-            Ema's parameter are updated with the formula:
-            `averaged_param = (1-momentum) * averaged_param + momentum *
-            source_param`. Defaults to 0.0002.
+            Defaults to 0.0002.
+            Ema's parameter are updated with the formula
+            :math:`averaged\_param = (1-momentum) * averaged\_param +
+            momentum * source\_param`.
         interval (int): Interval between two updates. Defaults to 1.
         device (torch.device, optional): If provided, the averaged model will
             be stored on the :attr:`device`. Defaults to None.
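
The formula above is a linear interpolation, so the in-place update style the class docstring credits for its speedup can be written as a single ``Tensor.lerp_`` call. A minimal sketch with plain tensors:

import torch

momentum = 0.0002
averaged_param = torch.ones(3)
source_param = torch.zeros(3)

# Equivalent to:
#   averaged_param = (1 - momentum) * averaged_param + momentum * source_param
# but performed in place, without allocating intermediate tensors.
averaged_param.lerp_(source_param, momentum)
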
@@ -184,14 +188,15 @@ class ExponentialMovingAverage(BaseAveragedModel):
 
 @MODELS.register_module()
 class MomentumAnnealingEMA(ExponentialMovingAverage):
-    """Exponential moving average (EMA) with momentum annealing strategy.
+    r"""Exponential moving average (EMA) with momentum annealing strategy.
 
     Args:
         model (nn.Module): The model to be averaged.
         momentum (float): The momentum used for updating ema parameter.
-            Ema's parameter are updated with the formula:
-            `averaged_param = (1-momentum) * averaged_param + momentum *
-            source_param`. Defaults to 0.0002.
+            Defaults to 0.0002.
+            Ema's parameter are updated with the formula
+            :math:`averaged\_param = (1-momentum) * averaged\_param +
+            momentum * source\_param`.
         gamma (int): Use a larger momentum early in training and gradually
             annealing to a smaller value to update the ema model smoothly. The
             momentum is calculated as max(momentum, gamma / (gamma + steps))
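
The annealing rule in the ``gamma`` description reduces to one line of Python. A sketch with hypothetical values, just to show the shape of the schedule:

def annealed_momentum(momentum: float, gamma: int, steps: int) -> float:
    # Large early in training (gamma / (gamma + 0) == 1.0), decaying toward
    # the configured floor ``momentum`` as steps grow.
    return max(momentum, gamma / (gamma + steps))

print(annealed_momentum(0.0002, gamma=100, steps=0))          # 1.0
print(annealed_momentum(0.0002, gamma=100, steps=1_000_000))  # clamped to 0.0002
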