Fix docstring formats (#383)

* update doc formats

* update docstring
Kai Chen 2020-07-04 00:55:25 +08:00 committed by GitHub
parent a47451b47a
commit 63b7aa31b6
38 changed files with 112 additions and 110 deletions


@ -28,6 +28,11 @@ repos:
args: ["--remove"]
- id: mixed-line-ending
args: ["--fix=lf"]
- repo: https://github.com/myint/docformatter
rev: v1.3.1
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: local
hooks:
- id: clang-format


@ -45,4 +45,4 @@ runner
ops
------
.. automodule:: mmcv.ops
:members:
:members:


@ -145,7 +145,7 @@ img_ = mmcv.impad(img, shape=(1000, 1200), pad_val=[100, 50, 200])
# pad the image on left, right, top, bottom borders with all zeros
img_ = mmcv.impad(img, padding=(10, 20, 30, 40), pad_val=0)
# pad the image on left, right, top, bottom borders with different values
# pad the image on left, right, top, bottom borders with different values
# for three channels.
img_ = mmcv.impad(img, padding=(10, 20, 30, 40), pad_val=[100, 50, 200])


@ -18,7 +18,7 @@ from mmcv.runner import DistSamplerSeedHook, Runner
def accuracy(output, target, topk=(1, )):
"""Computes the precision@k for the specified values of k"""
"""Computes the precision@k for the specified values of k."""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
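For context, the rest of this helper is the standard top-k accuracy computation; a hedged completion (everything past the lines shown above is reconstructed, not quoted from the repo):

```python
import torch

def accuracy(output, target, topk=(1, )):
    """Computes the precision@k for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # top-k predicted classes, transposed to (maxk, batch)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
```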


@ -2,8 +2,8 @@
from .alexnet import AlexNet
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
ContextBlock, ConvModule, GeneralizedAttention,
NonLocal1d, NonLocal2d, NonLocal3d, Scale,
ContextBlock, ConvModule, GeneralizedAttention, HSigmoid,
HSwish, NonLocal1d, NonLocal2d, NonLocal3d, Scale,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_plugin_layer,
build_upsample_layer, is_norm)
@ -20,7 +20,7 @@ __all__ = [
'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
'build_padding_layer', 'build_upsample_layer', 'build_plugin_layer',
'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ContextBlock',
'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
'get_model_complexity_info'
'HSigmoid', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info'
]


@ -32,12 +32,13 @@ class GeneralizedAttention(nn.Module):
Default: 1.
attention_type (str): A binary indicator string for indicating which
items in generalized empirical_attention module are used.
'1000' indicates 'query and key content' (appr - appr) item,
'0100' indicates 'query content and relative position'
(appr - position) item,
'0010' indicates 'key content only' (bias - appr) item,
'0001' indicates 'relative position only' (bias - position) item.
Default: '1111'.
- '1000' indicates 'query and key content' (appr - appr) item,
- '0100' indicates 'query content and relative position'
(appr - position) item,
- '0010' indicates 'key content only' (bias - appr) item,
- '0001' indicates 'relative position only' (bias - position) item.
"""
_abbr_ = 'gen_attention_block'
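To make the indicator string concrete, a hedged usage sketch (the constructor argument name and the shape-preserving output are assumptions based on mmcv's public API):

```python
import torch
from mmcv.cnn import GeneralizedAttention

# '1010' enables only the 'query and key content' and 'key content only' terms
attn = GeneralizedAttention(in_channels=256, attention_type='1010')
out = attn(torch.rand(1, 256, 32, 32))  # assumed to keep the input shape
```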


@ -5,8 +5,12 @@ from .registry import ACTIVATION_LAYERS
@ACTIVATION_LAYERS.register_module()
class HSwish(nn.Module):
"""Hard Swish Module. Apply the hard swish function:
Hswish(x) = x * ReLU6(x + 3) / 6
"""Hard Swish Module.
This module applies the hard swish function:
.. math::
Hswish(x) = x * ReLU6(x + 3) / 6
Args:
inplace (bool): can optionally do the operation in-place.
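The formula translates directly to code; a minimal functional sketch of the same computation (not the module itself):

```python
import torch
import torch.nn.functional as F

def hswish(x):
    # Hswish(x) = x * ReLU6(x + 3) / 6
    return x * F.relu6(x + 3) / 6

x = torch.linspace(-4, 4, 9)
print(hswish(x))  # ~0 for x <= -3, approaches x for x >= 3
```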


@ -73,6 +73,7 @@ def build_norm_layer(cfg, num_features, postfix=''):
Args:
cfg (dict): The norm layer config, which should contain:
- type (str): Layer type.
- layer args: Args needed to instantiate a norm layer.
- requires_grad (bool, optional): Whether to stop gradient updates.
@ -81,10 +82,9 @@ def build_norm_layer(cfg, num_features, postfix=''):
to create named layer.
Returns:
tuple[str, nn.Module]:
name (str): The layer name consisting of abbreviation and postfix,
e.g., bn1, gn.
layer (nn.Module): Created norm layer.
(str, nn.Module): The first element is the layer name consisting of
abbreviation and postfix, e.g., bn1, gn. The second element is the
created norm layer.
"""
if not isinstance(cfg, dict):
raise TypeError('cfg must be a dict')
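A hedged usage sketch of the documented return tuple (assuming 'BN' resolves to a BatchNorm layer in the default registry):

```python
from mmcv.cnn import build_norm_layer

name, layer = build_norm_layer(dict(type='BN', requires_grad=True), 64, postfix=1)
# name == 'bn1'; layer is a BatchNorm module over 64 features
```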


@ -31,7 +31,6 @@ def infer_abbr(class_type):
>>> camel2snack("FancyBlock")
'fancy_block'
"""
word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)


@ -52,14 +52,15 @@ def build_upsample_layer(cfg, *args, **kwargs):
Args:
cfg (dict): The upsample layer config, which should contain:
- type (str): Layer type.
- scale_factor (int): Upsample ratio, which is not applicable to
deconv.
- layer args: Args needed to instantiate an upsample layer.
args (argument list): Arguments passed to the `__init__`
method of the corresponding conv layer.
kwargs (keyword arguments): Keyword arguments passed to the `__init__`
args (argument list): Arguments passed to the ``__init__``
method of the corresponding conv layer.
kwargs (keyword arguments): Keyword arguments passed to the
``__init__`` method of the corresponding conv layer.
Returns:
nn.Module: Created upsample layer.
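A hedged usage sketch (assuming 'nearest' is registered to nn.Upsample, as in mmcv's default registry):

```python
import torch
from mmcv.cnn import build_upsample_layer

# scale_factor applies to interpolation layers; it is not used for 'deconv'
up = build_upsample_layer(dict(type='nearest', scale_factor=2))
out = up(torch.rand(1, 3, 8, 8))  # -> (1, 3, 16, 16)
```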


@ -9,7 +9,7 @@ from .utils import constant_init, kaiming_init
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
"""3x3 convolution with padding"""
"""3x3 convolution with padding."""
return nn.Conv2d(
in_planes,
out_planes,
@ -75,8 +75,8 @@ class Bottleneck(nn.Module):
with_cp=False):
"""Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__()
assert style in ['pytorch', 'caffe']


@ -45,18 +45,19 @@ def get_model_complexity_info(model,
each layer in a model.
Supported layers are listed as below:
- Convolutions: `nn.Conv1d`, `nn.Conv2d`, `nn.Conv3d`.
- Activations: `nn.ReLU`, `nn.PReLU`, `nn.ELU`, `nn.LeakyReLU`,
`nn.ReLU6`.
- Poolings: `nn.MaxPool1d`, `nn.MaxPool2d`, `nn.MaxPool3d`,
`nn.AvgPool1d`, `nn.AvgPool2d`, `nn.AvgPool3d`,
`nn.AdaptiveMaxPool1d`, `nn.AdaptiveMaxPool2d`,
`nn.AdaptiveMaxPool3d`, `nn.AdaptiveAvgPool1d`,
`nn.AdaptiveAvgPool2d`, `nn.AdaptiveAvgPool3d`.
- BatchNorms: `nn.BatchNorm1d`, `nn.BatchNorm2d`, `nn.BatchNorm3d`.
- Linear: `nn.Linear`.
- Deconvolution: `nn.ConvTranspose2d`.
- Upsample: `nn.Upsample`.
- Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
- Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
``nn.ReLU6``.
- Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
- BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
``nn.BatchNorm3d``.
- Linear: ``nn.Linear``.
- Deconvolution: ``nn.ConvTranspose2d``.
- Upsample: ``nn.Upsample``.
Args:
model (nn.Module): The model for complexity calculation.
@ -69,11 +70,11 @@ def get_model_complexity_info(model,
method that generates input. Otherwise, it will generate a random
tensor with input shape to calculate FLOPs. Default: None.
flush (bool): same as that in :func:`print`. Default: False.
ost (stream): same as `file` param in :func:`print`.
ost (stream): same as ``file`` param in :func:`print`.
Default: sys.stdout.
Returns:
tuple[float | str]: If `as_strings` is set to True, it will return
tuple[float | str]: If ``as_strings`` is set to True, it will return
FLOPs and parameter counts in a string format. Otherwise, it will
return those in a float number format.
"""
@ -352,7 +353,7 @@ def start_flops_count(self):
"""Activate the computation of mean flops consumption per image.
A method to activate the computation of mean flops consumption per image,
which will be available after `add_flops_counting_methods()` is called on
which will be available after ``add_flops_counting_methods()`` is called on
a desired net object. It should be called before running the network.
"""
add_batch_counter_hook_function(self)
@ -374,9 +375,9 @@ def start_flops_count(self):
def stop_flops_count(self):
"""Stop computing the mean flops consumption per image.
A method to stop computing the mean flops consumption per image, which
will be available after `add_flops_counting_methods()` is called on a
desired net object. It can be called to pause the computation whenever.
A method to stop computing the mean flops consumption per image, which will
be available after ``add_flops_counting_methods()`` is called on a desired
net object. It can be called to pause the computation at any time.
"""
remove_batch_counter_hook_function(self)
self.apply(remove_flops_counter_hook_function)
@ -385,8 +386,8 @@ def stop_flops_count(self):
def reset_flops_count(self):
"""Reset statistics computed so far.
A method to Reset computed statistics, which will be available
after `add_flops_counting_methods()` is called on a desired net object.
A method to reset computed statistics, which will be available after
``add_flops_counting_methods()`` is called on a desired net object.
"""
add_batch_counter_variables_or_reset(self)
self.apply(add_flops_counter_variable_or_reset)


@ -61,6 +61,6 @@ def caffe2_xavier_init(module, bias=0):
def bias_init_with_prob(prior_prob):
""" initialize conv/fc bias value according to giving probablity"""
"""initialize conv/fc bias value according to giving probablity."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init
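The returned value is the inverse sigmoid (logit) of the prior, so a sigmoid over the initialized bias reproduces the desired probability; a quick check:

```python
import numpy as np

prior_prob = 0.01
bias = float(-np.log((1 - prior_prob) / prior_prob))  # same formula as above
assert np.isclose(1 / (1 + np.exp(-bias)), prior_prob)  # sigmoid(bias) == prior
```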


@ -8,7 +8,7 @@ from .utils import constant_init, kaiming_init, normal_init
def conv3x3(in_planes, out_planes, dilation=1):
"""3x3 convolution with padding"""
"""3x3 convolution with padding."""
return nn.Conv2d(
in_planes,
out_planes,


@ -6,8 +6,8 @@ from abc import ABCMeta, abstractmethod
class BaseStorageBackend(metaclass=ABCMeta):
"""Abstract class of storage backends.
All backends need to implement two apis: `get()` and `get_text()`.
`get()` reads the file as a byte stream and `get_text()` reads the file
All backends need to implement two apis: ``get()`` and ``get_text()``.
``get()`` reads the file as a byte stream and ``get_text()`` reads the file
as texts.
"""
@ -25,8 +25,8 @@ class CephBackend(BaseStorageBackend):
Args:
path_mapping (dict|None): path mapping dict from local path to Petrel
path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will
be replaced by `dst`. Default: None.
path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
will be replaced by ``dst``. Default: None.
"""
def __init__(self, path_mapping=None):


@ -336,6 +336,7 @@ def impad(img,
areas when padding_mode is 'constant'. Default: 0.
padding_mode (str): Type of padding. Should be: constant, edge,
reflect or symmetric. Default: constant.
- constant: pads with a constant value, this value is specified
with pad_val.
- edge: pads with the last value at the edge of the image.
@ -370,8 +371,8 @@ def impad(img,
elif isinstance(padding, numbers.Number):
padding = (padding, padding, padding, padding)
else:
raise ValueError("Padding must be a int or a 2, or 4 element tuple."
f"But received {padding}")
raise ValueError('Padding must be an int or a 2- or 4-element tuple. '
f'But received {padding}')
# check padding mode
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
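A hedged example of the non-constant modes (border values come from the image itself, so pad_val is unused):

```python
import numpy as np
import mmcv

img = np.random.rand(32, 32, 3).astype(np.float32)
# 10 px on every border, mirrored from the image content
padded = mmcv.impad(img, padding=10, padding_mode='reflect')
print(padded.shape)  # (52, 52, 3)
```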


@ -70,7 +70,7 @@ def _jpegflag(flag='color', channel_order='bgr'):
def _pillow2array(img, flag='color', channel_order='bgr'):
"""Convert a pillow image to numpy array
"""Convert a pillow image to numpy array.
Args:
img (:obj:`PIL.Image.Image`): The image loaded using PIL
@ -215,7 +215,7 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
def imwrite(img, file_path, params=None, auto_mkdir=True):
"""Write image to file
"""Write image to file.
Args:
img (ndarray): Image array to be written.


@ -9,7 +9,7 @@ except ImportError:
def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
"""Convert tensor to 3-channel images
"""Convert tensor to 3-channel images.
Args:
tensor (torch.Tensor): Tensor that contains multiple images, shape (


@ -53,7 +53,7 @@ def imdenormalize(img, mean, std, to_bgr=True):
def iminvert(img):
"""Invert (negate) an image
"""Invert (negate) an image.
Args:
img (ndarray): Image to be inverted.


@ -240,11 +240,12 @@ class DeformConv2dPack(DeformConv2d):
The offset tensor is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
The spatial arrangement is like:
```
(x0, y0) (x1, y1) (x2, y2)
(x3, y3) (x4, y4) (x5, y5)
(x6, y6) (x7, y7) (x8, y8)
```
.. code:: text
(x0, y0) (x1, y1) (x2, y2)
(x3, y3) (x4, y4) (x5, y5)
(x6, y6) (x7, y7) (x8, y8)
Args:
in_channels (int): Same as nn.Conv2d.
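To tie the layout above to the module, a hedged sketch (the conv_offset attribute name follows mmcv's implementation and should be treated as an assumption):

```python
from mmcv.ops import DeformConv2dPack

layer = DeformConv2dPack(16, 32, kernel_size=3, padding=1)
# the internal offset conv predicts 2 * 3 * 3 = 18 channels: one (y, x)
# pair per kernel position, arranged as shown above
print(layer.conv_offset.out_channels)  # 18
```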


@ -216,7 +216,7 @@ def nms_match(dets, iou_threshold):
"""Matched dets into different groups by NMS.
NMS match is similar to NMS, but when a bbox is suppressed, nms match will
record the indice of supporessed bbox and form a group with the indice of
record the index of the suppressed bbox and form a group with the index of
the kept bbox. In each group, indices are sorted in score order.
Arguments:
@ -224,9 +224,9 @@ def nms_match(dets, iou_threshold):
iou_thr (float): IoU thresh for NMS.
Returns:
List[Tensor | ndarray]: The outer list corresponds different matched
group, the inner Tensor corresponds the indices for a group in
score order.
List[torch.Tensor | np.ndarray]: The outer list corresponds different
matched group, the inner Tensor corresponds the indices for a group
in score order.
"""
if dets.shape[0] == 0:
matched = []
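A hedged usage sketch (the grouping shown in the comment is plausible, not verified output):

```python
import numpy as np
from mmcv.ops import nms_match

dets = np.array([[0, 0, 10, 10, 0.9],
                 [1, 1, 11, 11, 0.8],
                 [50, 50, 60, 60, 0.7]], dtype=np.float32)
groups = nms_match(dets, 0.5)
# plausibly [[0, 1], [2]]: box 1 overlaps box 0 heavily, box 2 stands alone
```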


@ -134,9 +134,9 @@ def rel_roi_point_to_rel_img_point(rois,
def point_sample(input, points, align_corners=False, **kwargs):
"""A wrapper around :function:`grid_sample` to support 3D point_coords
tensors Unlike :function:`torch.nn.functional.grid_sample` it assumes
point_coords to lie inside [0, 1] x [0, 1] square.
"""A wrapper around :func:`grid_sample` to support 3D point_coords tensors
Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
lie inside ``[0, 1] x [0, 1]`` square.
Args:
input (Tensor): Feature map, shape (N, C, H, W).
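A hedged usage sketch (the (N, C, P) output shape is an assumption from the normalized-point semantics described above):

```python
import torch
from mmcv.ops import point_sample

feats = torch.rand(2, 256, 32, 32)
points = torch.rand(2, 100, 2)       # (x, y) pairs, normalized to [0, 1]
out = point_sample(feats, points)    # assumed shape: (2, 256, 100)
```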


@ -4,8 +4,7 @@ from torch.nn.parallel._functions import _get_stream
def scatter(input, devices, streams=None):
"""Scatters tensor across multiple GPUs.
"""
"""Scatters tensor across multiple GPUs."""
if streams is None:
streams = [None] * len(devices)


@ -43,7 +43,7 @@ def scatter(inputs, target_gpus, dim=0):
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
"""Scatter with support for kwargs dictionary"""
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):


@ -27,5 +27,5 @@ __all__ = [
'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only',
'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
'IterBasedRunner', 'set_random_seed'
'set_random_seed'
]


@ -247,7 +247,7 @@ class BaseRunner(metaclass=ABCMeta):
"""Register a hook into the hook list.
The hook will be inserted into a priority queue, with the specified
priority (See :cls:`Priority` for details of priorities).
priority (See :class:`Priority` for details of priorities).
For hooks with the same priority, they will be triggered in the same
order as they are registered.


@ -103,8 +103,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
def load_url_dist(url, model_dir=None):
""" In distributed setting, this function only download checkpoint at
local rank 0 """
"""In distributed setting, this function only download checkpoint at local
rank 0."""
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:


@ -161,7 +161,7 @@ class EpochBasedRunner(BaseRunner):
class Runner(EpochBasedRunner):
"""Deprecated name of EpochBasedRunner"""
"""Deprecated name of EpochBasedRunner."""
def __init__(self, *args, **kwargs):
warnings.warn(


@ -5,7 +5,7 @@ from .hook import HOOKS, Hook
class LrUpdaterHook(Hook):
"""LR Scheduler in MMCV
"""LR Scheduler in MMCV.
Args:
by_epoch (bool): LR changes epoch by epoch
@ -325,7 +325,7 @@ def get_position_from_periods(iteration, cumulative_periods):
@HOOKS.register_module()
class CyclicLrUpdaterHook(LrUpdaterHook):
"""Cyclic LR Scheduler
"""Cyclic LR Scheduler.
Implement the cyclical learning rate policy (CLR) described in
https://arxiv.org/pdf/1506.01186.pdf
@ -341,7 +341,6 @@ class CyclicLrUpdaterHook(LrUpdaterHook):
step_ratio_up (float): The ratio of the increasing process of LR in
the total cycle.
by_epoch (bool): Whether to update LR by epoch.
"""
def __init__(self,
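For orientation, a hedged instantiation built from the documented arguments (the import path and the default values shown are assumptions):

```python
from mmcv.runner.hooks import CyclicLrUpdaterHook

# one cycle over training; LR ramps up for 40% of the cycle, then anneals
hook = CyclicLrUpdaterHook(by_epoch=False, target_ratio=(10, 1e-4),
                           cyclic_times=1, step_ratio_up=0.4)
```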


@ -128,7 +128,7 @@ class CosineAnealingMomentumUpdaterHook(MomentumUpdaterHook):
@HOOKS.register_module()
class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
"""Cyclic momentum Scheduler
"""Cyclic momentum Scheduler.
Implement the cyclical momentum scheduler policy described in
https://arxiv.org/pdf/1708.07120.pdf
@ -143,7 +143,6 @@ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
step_ratio_up (float): The ratio of the increasing process of momentum
in the total cycle.
by_epoch (bool): Whether to update momentum by epoch.
"""
def __init__(self,


@ -31,7 +31,7 @@ class LogBuffer:
self.n_history[key].append(count)
def average(self, n=0):
"""Average latest n values or all values"""
"""Average latest n values or all values."""
assert n >= 0
for key in self.val_history:
values = np.array(self.val_history[key][-n:])


@ -44,8 +44,11 @@ class DefaultOptimizerConstructor:
model (:obj:`nn.Module`): The model with parameters to be optimized.
optimizer_cfg (dict): The config dict of the optimizer.
Positional fields are
- `type`: class name of the optimizer.
Optional fields are
- any arguments of the corresponding optimizer type, e.g.,
lr, weight_decay, momentum, etc.
paramwise_cfg (dict, optional): Parameter-wise options.
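A hedged usage sketch via build_optimizer, which is exported alongside the constructor (see the __all__ diff above):

```python
import torch.nn as nn
from mmcv.runner import build_optimizer

model = nn.Linear(4, 2)
# 'type' is the required field; the rest is forwarded to torch.optim.SGD
optimizer = build_optimizer(
    model, dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4))
```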


@ -78,7 +78,6 @@ class Config:
>>> cfg
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
"""
@staticmethod
@ -180,8 +179,7 @@ class Config:
@staticmethod
def auto_argparser(description=None):
"""Generate argparser from config file automatically (experimental)
"""
"""Generate argparser from config file automatically (experimental)"""
partial_parser = ArgumentParser(description=description)
partial_parser.add_argument('config', help='config file path')
cfg_file = partial_parser.parse_known_args()[0].config
@ -356,7 +354,7 @@ class Config:
mmcv.dump(cfg_dict, file)
def merge_from_dict(self, options):
"""Merge list into cfg_dict
"""Merge list into cfg_dict.
Merge the dict parsed by MultipleKVAction into this cfg.
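A hedged usage sketch of the dotted-key merge this method performs:

```python
from mmcv import Config

cfg = Config(dict(model=dict(backbone=dict(depth=50))))
cfg.merge_from_dict({'model.backbone.depth': 101})
assert cfg.model.backbone.depth == 101
```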


@ -8,7 +8,7 @@ from .timer import Timer
class ProgressBar:
"""A progress bar which can print the progress"""
"""A progress bar which can print the progress."""
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
self.task_num = task_num
@ -176,7 +176,8 @@ def track_parallel_progress(func,
def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
"""Track the progress of tasks iteration or enumeration with a progress bar.
"""Track the progress of tasks iteration or enumeration with a progress
bar.
Tasks are yielded with a simple for-loop.


@ -200,7 +200,7 @@ class VideoReader:
start=0,
max_num=0,
show_progress=True):
"""Convert a video to frame images
"""Convert a video to frame images.
Args:
frame_dir (str): Output directory to store all the frame images.
@ -282,7 +282,7 @@ def frames2video(frame_dir,
start=0,
end=0,
show_progress=True):
"""Read the frame images from a directory and join them as a video
"""Read the frame images from a directory and join them as a video.
Args:
frame_dir (str): The directory containing video frames.


@ -139,7 +139,7 @@ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
"""Use flow to warp img
"""Use flow to warp img.
Args:
img (ndarray, float or uint8): Image to be warped.


@ -14,9 +14,8 @@ from Cython.Distutils import build_ext as build_cmd # NOQA: E402 # isort:skip
def choose_requirement(primary, secondary):
"""If some version of primary requirement installed, return primary,
else return secondary.
"""
"""If some version of primary requirement installed, return primary, else
return secondary."""
try:
name = re.split(r'[!<>=]', primary)[0]
get_distribution(name)
@ -40,8 +39,7 @@ def get_version():
def parse_requirements(fname='requirements.txt', with_version=True):
"""
Parse the package dependencies listed in a requirements file but strips
"""Parse the package dependencies listed in a requirements file but strips
specific versioning information.
Args:
@ -60,9 +58,7 @@ def parse_requirements(fname='requirements.txt', with_version=True):
require_fpath = fname
def parse_line(line):
"""
Parse information from a line in a requirements text file
"""
"""Parse information from a line in a requirements text file."""
if line.startswith('-r '):
# Allow specifying requirements in other files
target = line.split(' ')[1]


@ -1,10 +1,8 @@
"""
Tests the hooks with runners.
"""Tests the hooks with runners.
CommandLine:
pytest tests/test_hooks.py
xdoctest tests/test_hooks.py zero
"""
import logging
import os.path as osp
@ -49,9 +47,7 @@ def test_pavi_hook():
def test_momentum_runner_hook():
"""
xdoctest -m tests/test_hooks.py test_momentum_runner_hook
"""
"""xdoctest -m tests/test_hooks.py test_momentum_runner_hook."""
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner()
@ -99,9 +95,7 @@ def test_momentum_runner_hook():
def test_cosine_runner_hook():
"""
xdoctest -m tests/test_hooks.py test_cosine_runner_hook
"""
"""xdoctest -m tests/test_hooks.py test_cosine_runner_hook."""
sys.modules['pavi'] = MagicMock()
loader = DataLoader(torch.ones((10, 2)))
runner = _build_demo_runner()