[Docs] Fix the format of the docstring (#1573)

* [Docs] Fix the format of docstring

* fix format
This commit is contained in:
Zaida Zhou 2021-12-09 22:15:52 +08:00 committed by GitHub
parent 53c1b2fe91
commit 222f38075b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 136 additions and 138 deletions

View File

@ -83,6 +83,7 @@ def build_activation_layer(cfg):
Args:
cfg (dict): The activation layer config, which should contain:
- type (str): Layer type.
- layer args: Args needed to instantiate an activation layer.

View File

@ -83,8 +83,8 @@ def build_norm_layer(cfg, num_features, postfix=''):
to create named layer.
Returns:
(str, nn.Module): The first element is the layer name consisting of
abbreviation and postfix, e.g., bn1, gn. The second element is the
tuple[str, nn.Module]: The first element is the layer name consisting
of abbreviation and postfix, e.g., bn1, gn. The second element is the
created norm layer.
"""
if not isinstance(cfg, dict):

View File

@ -57,15 +57,15 @@ def build_plugin_layer(cfg, postfix='', **kwargs):
Args:
cfg (None or dict): cfg should contain:
type (str): identify plugin layer type.
layer args: args needed to instantiate a plugin layer.
- type (str): identify plugin layer type.
- layer args: args needed to instantiate a plugin layer.
postfix (int, str): appended into norm abbreviation to
create named layer. Default: ''.
Returns:
tuple[str, nn.Module]:
name (str): abbreviation + postfix
layer (nn.Module): created plugin layer
tuple[str, nn.Module]: The first one is the concatenation of
abbreviation and postfix. The second is the created plugin layer.
"""
if not isinstance(cfg, dict):
raise TypeError('cfg must be a dict')

View File

@ -48,8 +48,8 @@ def get_model_complexity_info(model,
Supported layers are listed as below:
- Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
- Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
``nn.ReLU6``.
- Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``,
``nn.LeakyReLU``, ``nn.ReLU6``.
- Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,

View File

@ -464,13 +464,13 @@ def impad(img,
- constant: pads with a constant value, this value is specified
with pad_val.
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the
last value on the edge. For example, padding [1, 2, 3, 4]
with 2 elements on both sides in reflect mode will result
in [3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with
2 elements on both sides in symmetric mode will result in
- reflect: pads with reflection of image without repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with 2
elements on both sides in reflect mode will result in
[3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last value
on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3]
Returns:

View File

@ -25,11 +25,10 @@ class _DynamicScatter(Function):
'mean'. Default: 'max'.
Returns:
tuple[torch.Tensor]: tuple[torch.Tensor]: A tuple contains two
elements. The first one is the voxel features with shape [M, C]
which are respectively reduced from input features that share
the same voxel coordinates . The second is voxel coordinates
with shape [M, ndim].
tuple[torch.Tensor]: A tuple contains two elements. The first one
is the voxel features with shape [M, C] which are respectively
reduced from input features that share the same voxel coordinates.
The second is voxel coordinates with shape [M, ndim].
"""
results = ext_module.dynamic_point_to_voxel_forward(
feats, coors, reduce_type)
@ -89,11 +88,10 @@ class DynamicScatter(nn.Module):
multi-dim voxel index) of each points.
Returns:
tuple[torch.Tensor]: tuple[torch.Tensor]: A tuple contains two
elements. The first one is the voxel features with shape [M, C]
which are respectively reduced from input features that share
the same voxel coordinates . The second is voxel coordinates
with shape [M, ndim].
tuple[torch.Tensor]: A tuple contains two elements. The first one
is the voxel features with shape [M, C] which are respectively
reduced from input features that share the same voxel coordinates.
The second is voxel coordinates with shape [M, ndim].
"""
reduce = 'mean' if self.average_points else 'max'
return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce)
@ -107,11 +105,10 @@ class DynamicScatter(nn.Module):
multi-dim voxel index) of each points.
Returns:
tuple[torch.Tensor]:tuple[torch.Tensor]: A tuple contains two
elements. The first one is the voxel features with shape [M, C]
which are respectively reduced from input features that share
the same voxel coordinates . The second is voxel coordinates
with shape [M, ndim].
tuple[torch.Tensor]: A tuple contains two elements. The first one
is the voxel features with shape [M, C] which are respectively
reduced from input features that share the same voxel coordinates.
The second is voxel coordinates with shape [M, ndim].
"""
if coors.size(-1) == 3:
return self.forward_single(points, coors)

View File

@ -19,12 +19,11 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
- ``init_cfg``: the config to control the initialization.
- ``init_weights``: The function of parameter
initialization and recording initialization
information.
- ``_params_init_info``: Used to track the parameter
initialization information. This attribute only
exists during executing the ``init_weights``.
- ``init_weights``: The function of parameter initialization and recording
initialization information.
- ``_params_init_info``: Used to track the parameter initialization
information. This attribute only exists during executing the
``init_weights``.
Args:
init_cfg (dict, optional): Initialization config dict.

View File

@ -207,8 +207,8 @@ class BaseRunner(metaclass=ABCMeta):
Returns:
list[float] | dict[str, list[float]]: Current learning rates of all
param groups. If the runner has a dict of optimizers, this
method will return a dict.
param groups. If the runner has a dict of optimizers, this method
will return a dict.
"""
if isinstance(self.optimizer, torch.optim.Optimizer):
lr = [group['lr'] for group in self.optimizer.param_groups]
@ -226,8 +226,8 @@ class BaseRunner(metaclass=ABCMeta):
Returns:
list[float] | dict[str, list[float]]: Current momentums of all
param groups. If the runner has a dict of optimizers, this
method will return a dict.
param groups. If the runner has a dict of optimizers, this method
will return a dict.
"""
def _get_momentum(optimizer):
@ -287,7 +287,7 @@ class BaseRunner(metaclass=ABCMeta):
hook_cfg (dict): Hook config. It should have at least keys 'type'
and 'priority' indicating its type and priority.
Notes:
Note:
The specific hook class to register should not use 'type' and
'priority' arguments during initialization.
"""

View File

@ -13,8 +13,8 @@ class EMAHook(Hook):
.. math::
\text{Xema\_{t+1}} = (1 - \text{momentum}) \times
\text{Xema\_{t}} + \text{momentum} \times X_t
Xema\_{t+1} = (1 - \text{momentum}) \times
Xema\_{t} + \text{momentum} \times X_t
Args:
momentum (float): The momentum used for updating ema parameter.

View File

@ -12,19 +12,21 @@ class NeptuneLoggerHook(LoggerHook):
Args:
init_kwargs (dict): a dict contains the initialization keys as below:
- project (str): Name of a project in a form of
namespace/project_name. If None, the value of
NEPTUNE_PROJECT environment variable will be taken.
- api_token (str): Users API token.
If None, the value of NEPTUNE_API_TOKEN environment
variable will be taken. Note: It is strongly recommended
to use NEPTUNE_API_TOKEN environment variable rather than
placing your API token in plain text in your source code.
- name (str, optional, default is 'Untitled'): Editable name of
the run. Name is displayed in the run's Details and in
Runs table as a column.
Check https://docs.neptune.ai/api-reference/neptune#init for
more init arguments.
namespace/project_name. If None, the value of NEPTUNE_PROJECT
environment variable will be taken.
- api_token (str): Users API token. If None, the value of
NEPTUNE_API_TOKEN environment variable will be taken. Note: It is
strongly recommended to use NEPTUNE_API_TOKEN environment
variable rather than placing your API token in plain text in your
source code.
- name (str, optional, default is 'Untitled'): Editable name of the
run. Name is displayed in the run's Details and in Runs table as
a column.
Check https://docs.neptune.ai/api-reference/neptune#init for more
init arguments.
interval (int): Logging interval (every k iterations).
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.

View File

@ -344,7 +344,7 @@ class Config:
config str. Only py/yml/yaml/json type are supported now!
Returns:
obj:`Config`: Config obj.
:obj:`Config`: Config obj.
"""
if file_format not in ['.py', '.json', '.yaml', '.yml']:
raise IOError('Only py/yml/yaml/json type are supported now!')
@ -561,7 +561,7 @@ class Config:
>>> assert cfg_dict == dict(
... model=dict(backbone=dict(depth=50, with_cp=True)))
# Merge list element
>>> # Merge list element
>>> cfg = Config(dict(pipeline=[
... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
>>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})

View File

@ -40,7 +40,7 @@ def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
"""Scan a directory to find the interested files.
Args:
dir_path (str | obj:`Path`): Path of the directory.
dir_path (str | :obj:`Path`): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the

View File

@ -59,6 +59,7 @@ class Registry:
"""A registry to map strings to classes.
Registered object could be built from registry.
Example:
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
@ -128,16 +129,15 @@ class Registry:
The name of the package where registry is defined will be returned.
Example:
# in mmdet/models/backbone/resnet.py
>>> # in mmdet/models/backbone/resnet.py
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
The scope of ``ResNet`` will be ``mmdet``.
Returns:
scope (str): The inferred scope name.
str: The inferred scope name.
"""
# inspect.stack() trace where this function is called, the index-2
# indicates the frame where `infer_scope()` is called
@ -158,8 +158,8 @@ class Registry:
None, 'ResNet'
Return:
scope (str, None): The first scope.
key (str): The remaining key.
tuple[str | None, str]: The former element is the first scope of
the key, which can be ``None``. The latter is the remaining key.
"""
split_index = key.find('.')
if split_index != -1:

View File

@ -12,8 +12,7 @@ class TimerError(Exception):
class Timer:
"""A flexible Timer class.
:Example:
Examples:
>>> import time
>>> import mmcv
>>> with mmcv.Timer():
@ -64,7 +63,8 @@ class Timer:
def since_start(self):
"""Total time since the timer is started.
Returns (float): Time in seconds.
Returns:
float: Time in seconds.
"""
if not self._is_running:
raise TimerError('timer is not running')
@ -77,7 +77,8 @@ class Timer:
Either :func:`since_start` or :func:`since_last_check` is a checking
operation.
Returns (float): Time in seconds.
Returns:
float: Time in seconds.
"""
if not self._is_running:
raise TimerError('timer is not running')
@ -95,8 +96,7 @@ def check_time(timer_id):
This method is suitable for running a task on a list of items. A timer will
be registered when the method is called for the first time.
:Example:
Examples:
>>> import time
>>> import mmcv
>>> for i in range(1, 6):
@ -109,7 +109,7 @@ def check_time(timer_id):
5.000
Args:
timer_id (str): Timer identifier.
str: Timer identifier.
"""
if timer_id not in _g_timers:
_g_timers[timer_id] = Timer()

View File

@ -50,8 +50,7 @@ class VideoReader:
the second time, there is no need to decode again if it is stored in the
cache.
:Example:
Examples:
>>> import mmcv
>>> v = mmcv.VideoReader('sample.mp4')
>>> len(v) # get the total frame number with `len()`