Merge pull request #14 from open-mmlab/docs

Draft documentation
pull/16/head
Kai Chen 2018-10-06 15:46:00 +08:00 committed by GitHub
commit f4550cd319
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 862 additions and 25 deletions

19
docs/Makefile 100644
View File

@ -0,0 +1,19 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line:
#   SPHINXOPTS   extra options passed to sphinx-build
#   SPHINXBUILD  the sphinx-build executable to use
# $(O) below is a short command-line alias for SPHINXOPTS, e.g. `make html O=-W`.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SOURCEDIR     = .
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
# Note: recipe lines MUST begin with a hard tab character.
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

BIN
docs/_static/progress.gif vendored 100644

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

38
docs/api.rst 100644
View File

@ -0,0 +1,38 @@
API
====
fileio
--------------
.. automodule:: mmcv.fileio
:members:
image
--------------
.. automodule:: mmcv.image
:members:
video
--------------
.. automodule:: mmcv.video
:members:
arraymisc
--------------
.. automodule:: mmcv.arraymisc
:members:
visualization
--------------
.. automodule:: mmcv.visualization
:members:
utils
--------------
.. automodule:: mmcv.utils
:members:
runner
--------------
.. automodule:: mmcv.runner
:members:

170
docs/conf.py 100644
View File

@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
# Make the repository root importable so `import mmcv` below resolves the
# in-tree package rather than an installed copy.
sys.path.insert(0, os.path.abspath('..'))
import mmcv # noqa: E402
# -- Project information -----------------------------------------------------
project = 'mmcv'
copyright = '2018, Kai Chen'
author = 'Kai Chen'
# The short X.Y version
version = mmcv.__version__
# The full version, including alpha/beta/rc tags
release = mmcv.__version__
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
]
# Mock modules that may be unavailable on the docs build host so autodoc can
# still import mmcv and extract docstrings.
autodoc_mock_imports = ['cv2', 'torch']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = ['.rst', '.md']
# Parse Markdown sources with recommonmark's CommonMark parser.
source_parsers = {'.md': 'recommonmark.parser.CommonMarkParser'}
# The master toctree document.
master_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'mmcvdoc'
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'mmcv.tex', 'mmcv Documentation', 'Kai Chen', 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, 'mmcv', 'mmcv Documentation', [author], 1)]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'mmcv', 'mmcv Documentation', author, 'mmcv',
     'One line description of project.', 'Miscellaneous'),
]
# -- Options for Epub output -------------------------------------------------
# Bibliographic Dublin Core info.
epub_title = project
# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''
# A unique identification for the text.
#
# epub_uid = ''
# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
# -- Extension configuration -------------------------------------------------

148
docs/image.md 100644
View File

@ -0,0 +1,148 @@
## Image
This module provides some image processing methods, which require `opencv`
to be installed.
### Read/Write/Show
To read or write image files, use `imread` or `imwrite`.
```python
import mmcv
img = mmcv.imread('test.jpg')
img = mmcv.imread('test.jpg', flag='grayscale')
img_ = mmcv.imread(img) # nothing will happen, img_ = img
mmcv.imwrite(img, 'out.jpg')
```
To read images from bytes
```python
with open('test.jpg', 'rb') as f:
data = f.read()
img = mmcv.imfrombytes(data)
```
To show an image file or a loaded image
```python
mmcv.imshow('tests/data/color.jpg')
# this is equivalent to
for i in range(10):
img = np.random.randint(256, size=(100, 100, 3), dtype=np.uint8)
mmcv.imshow(img, win_name='test image', wait_time=200)
```
### Color space conversion
Supported conversion methods:
- bgr2gray
- gray2bgr
- bgr2rgb
- rgb2bgr
- bgr2hsv
- hsv2bgr
```python
img = mmcv.imread('tests/data/color.jpg')
img1 = mmcv.bgr2rgb(img)
img2 = mmcv.rgb2gray(img1)
img3 = mmcv.bgr2hsv(img)
```
### Resize
There are three resize methods. All `imresize_*` methods have an argument `return_scale`.
If this argument is `False`, the return value is merely the resized image; otherwise
it is a tuple `(resized_img, scale)`.
```python
# resize to a given size
mmcv.imresize(img, (1000, 600), return_scale=True)
# resize to the same size of another image
mmcv.imresize_like(img, dst_img, return_scale=False)
# resize by a ratio
mmcv.imrescale(img, 0.5)
# resize so that the max edge no longer than 1000, short edge no longer than 800
# without changing the aspect ratio
mmcv.imrescale(img, (1000, 800))
```
### Rotate
To rotate an image by some angle, use `imrotate`. The center can be specified,
which is the center of original image by default. There are two modes of rotating,
one is to keep the image size unchanged so that some parts of the image will be
cropped after rotating, the other is to extend the image size to fit the rotated
image.
```python
img = mmcv.imread('tests/data/color.jpg')
# rotate the image clockwise by 30 degrees.
img_ = mmcv.imrotate(img, 30)
# rotate the image counterclockwise by 90 degrees.
img_ = mmcv.imrotate(img, -90)
# rotate the image clockwise by 30 degrees, and rescale it by 1.5x at the same time.
img_ = mmcv.imrotate(img, 30, scale=1.5)
# rotate the image clockwise by 30 degrees, with (100, 100) as the center.
img_ = mmcv.imrotate(img, 30, center=(100, 100))
# rotate the image clockwise by 30 degrees, and extend the image size.
img_ = mmcv.imrotate(img, 30, auto_bound=True)
```
### Flip
To flip an image, use `imflip`.
```python
img = mmcv.imread('tests/data/color.jpg')
# flip the image horizontally
mmcv.imflip(img)
# flip the image vertically
mmcv.imflip(img, direction='vertical')
```
### Crop
`imcrop` can crop the image with one or more regions, each represented as (x1, y1, x2, y2).
```python
import mmcv
import numpy as np
img = mmcv.read_img('tests/data/color.jpg')
# crop the region (10, 10, 100, 120)
bboxes = np.array([10, 10, 100, 120])
patch = mmcv.crop_img(img, bboxes)
# crop two regions (10, 10, 100, 120) and (0, 0, 50, 50)
bboxes = np.array([[10, 10, 100, 120], [0, 0, 50, 50]])
patches = mmcv.crop_img(img, bboxes)
# crop two regions, and rescale the patches by 1.2x
patches = mmcv.crop_img(img, bboxes, scale_ratio=1.2)
```
### Padding
There are two methods `impad` and `impad_to_multiple` to pad an image to the
specific size with given values.
```python
img = mmcv.read_img('tests/data/color.jpg')
# pad the image to (1000, 1200) with all zeros
img_ = mmcv.pad_img(img, (1000, 1200), pad_val=0)
# pad the image to (1000, 1200) with different values for three channels.
img_ = mmcv.pad_img(img, (1000, 1200), pad_val=[100, 50, 200])
# pad an image so that each edge is a multiple of some value.
img_ = mmcv.impad_to_multiple(img, 32)
```

22
docs/index.rst 100644
View File

@ -0,0 +1,22 @@
Welcome to mmcv's documentation!
==================================
.. toctree::
:maxdepth: 2
intro.md
io.md
image.md
video.md
visualization.md
utils.md
runner.md
api.rst
Indices and tables
==================
* :ref:`genindex`
* :ref:`search`

33
docs/intro.md 100644
View File

@ -0,0 +1,33 @@
## Introduction
`mmcv` is a foundational python library for computer vision research and supports many
research projects in MMLAB, such as [mmdetection](https://github.com/open-mmlab/mmdetection).
It provides the following functionalities.
- Universal IO APIs
- Image processing
- Video processing
- Image and annotation visualization
- Useful utilities (progress bar, timer, ...)
- PyTorch runner with hooking mechanism
- Various CNN architectures
See the [documentation](http://mmcv.readthedocs.io/en/latest) for more features and usage.
### Installation
Try and start with
```shell
pip install mmcv
```
If you want to install from source
```shell
git clone https://github.com/open-mmlab/mmcv.git
cd mmcv
pip install . # (add "-e" if you want to develop or modify the codes)
```

120
docs/io.md 100644
View File

@ -0,0 +1,120 @@
## File IO
This module provides two universal APIs to load and dump files of different formats.
### Load and dump data
`mmcv` provides a universal api for loading and dumping data, currently
supported formats are json, yaml and pickle.
```python
import mmcv
# load data from a file
data = mmcv.load('test.json')
data = mmcv.load('test.yaml')
data = mmcv.load('test.pkl')
# load data from a file-like object
with open('test.json', 'r') as f:
data = mmcv.load(f)
# dump data to a string
json_str = mmcv.dump(data, format='json')
# dump data to a file with a filename (infer format from file extension)
mmcv.dump(data, 'out.pkl')
# dump data to a file with a file-like object
with open('test.yaml', 'w') as f:
data = mmcv.dump(data, f, format='yaml')
```
It is also very convenient to extend the api to support more file formats.
All you need to do is to write a file handler inherited from `BaseFileHandler`
and register it with one or several file formats.
You need to implement at least 3 methods.
```python
import mmcv
# To register multiple file formats, a list can be used as the argument.
# @mmcv.register_handler(['txt', 'log'])
@mmcv.register_handler('txt')
class TxtHandler1(mmcv.BaseFileHandler):
def load_from_fileobj(self, file):
return file.read()
def dump_to_fileobj(self, obj, file):
file.write(str(obj))
def dump_to_str(self, obj, **kwargs):
return str(obj)
```
Here is an example of `PickleHandler`.
```python
from six.moves import cPickle as pickle
class PickleHandler(mmcv.BaseFileHandler):
def load_from_fileobj(self, file, **kwargs):
return pickle.load(file, **kwargs)
def load_from_path(self, filepath, **kwargs):
return super(PickleHandler, self).load_from_path(
filepath, mode='rb', **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('protocol', 2)
return pickle.dumps(obj, **kwargs)
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault('protocol', 2)
pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path(
obj, filepath, mode='wb', **kwargs)
```
### Load a text file as a list or dict
For example `a.txt` is a text file with 5 lines.
```
a
b
c
d
e
```
Then use `list_from_file` to load the list from a.txt.
```python
>>> mmcv.list_from_file('a.txt')
['a', 'b', 'c', 'd', 'e']
>>> mmcv.list_from_file('a.txt', offset=2)
['c', 'd', 'e']
>>> mmcv.list_from_file('a.txt', max_num=2)
['a', 'b']
>>> mmcv.list_from_file('a.txt', prefix='/mnt/')
['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e']
```
For example `b.txt` is a text file with 3 lines.
```
1 cat
2 dog cow
3 panda
```
Then use `dict_from_file` to load the dict from b.txt.
```python
>>> mmcv.dict_from_file('b.txt')
{'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
>>> mmcv.dict_from_file('b.txt', key_type=int)
{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
```

35
docs/make.bat 100644
View File

@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
REM Default to the sphinx-build found on PATH unless the caller has already
REM set the SPHINXBUILD environment variable to an explicit executable.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
REM No target given: show Sphinx's help listing instead.
if "%1" == "" goto help
REM Probe that sphinx-build exists; errorlevel 9009 means "command not found".
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)
REM Forward the requested target to Sphinx "make mode" (-M).
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
:end
popd

4
docs/runner.md 100644
View File

@ -0,0 +1,4 @@
## Runner
The runner module aims to help users start training with less code, while staying
flexible and configurable.

81
docs/utils.md 100644
View File

@ -0,0 +1,81 @@
## Utils
### Config
`Config` class is used for manipulating config and config files. It supports
loading configs from multiple file formats including **python**, **json** and **yaml**.
It provides dict-like apis to get and set values.
Here is an example of the config file `test.py`.
```python
a = 1
b = {'b1': [0, 1, 2], 'b2': None}
c = (1, 2)
d = 'string'
```
To load and use configs
```python
cfg = Config.fromfile('test.py')
assert cfg.a == 1
assert cfg.b.b1 == [0, 1, 2]
cfg.c = None
assert cfg.c == None
```
### ProgressBar
If you want to apply a method to a list of items and track the progress, `track_progress`
is a good choice. It will display a progress bar to tell the progress and ETA.
```python
import mmcv
def func(item):
# do something
pass
tasks = [item_1, item_2, ..., item_n]
mmcv.track_progress(func, tasks)
```
The output is like the following.
![progress](_static/progress.gif)
There is another method `track_parallel_progress`, which wraps multiprocessing and
progress visualization.
```python
mmcv.track_parallel_progress(func, tasks, 8) # 8 workers
```
![progress](_static/parallel_progress.gif)
### Timer
It is convenient to compute the runtime of a code block with `Timer`.
```python
import time
with mmcv.Timer():
# simulate some code block
time.sleep(1)
```
or try with `since_start()` and `since_last_check()`. The former returns the
runtime since the timer started, and the latter returns the time elapsed since
the last check.
```python
timer = mmcv.Timer()
# code block 1 here
print(timer.since_start())
# code block 2 here
print(timer.since_last_check())
print(timer.since_start())
```

59
docs/video.md 100644
View File

@ -0,0 +1,59 @@
## Video
This module provides the following functionalities.
- A `VideoReader` class with friendly apis to read and convert videos.
- Some methods for editing (cut, concat, resize) videos.
- Optical flow read/write.
The `VideoReader` class provides sequence like apis to access video frames.
It will internally cache the frames which have been visited.
```python
video = mmcv.VideoReader('test.mp4')
# obtain basic information
print(len(video))
print(video.width, video.height, video.resolution, video.fps)
# iterate over all frames
for frame in video:
print(frame.shape)
# read the next frame
img = video.read()
# read a frame by index
img = video[100]
# read some frames
img = video[5:10]
```
To convert a video to images or generate a video from an image directory.
```python
# split a video into frames and save to a folder
video = mmcv.VideoReader('test.mp4')
video.cvt2frames('out_dir')
# generate video from frames
mmcv.frames2video('out_dir', 'test.avi')
```
There are also some methods for editing videos, which wraps the commands of ffmpeg.
```python
# cut a video clip
mmcv.cut_video('test.mp4', 'clip1.mp4', start=3, end=10, vcodec='h264')
# join a list of video clips
mmcv.concat_video(['clip1.mp4', 'clip2.mp4'], 'joined.mp4', log_level='quiet')
# resize a video with the specified size
mmcv.resize_video('test.mp4', 'resized1.mp4', (360, 240))
# resize a video with a scaling ratio of 2
mmcv.resize_video('test.mp4', 'resized2.mp4', ratio=2)
```

View File

@ -0,0 +1,24 @@
## Visualization
`mmcv` can show images and annotations (currently supported types include bounding boxes).
```python
# show an image file
mmcv.imshow('a.jpg')
# show a loaded image
img = np.random.rand(100, 100, 3)
mmcv.imshow(img)
# show image with bounding boxes
img = np.random.rand(100, 100, 3)
bboxes = np.array([[0, 0, 50, 50], [20, 20, 60, 60]])
mmcv.imshow_bboxes(img, bboxes)
```
`mmcv` can also visualize special images such as optical flows.
```python
flow = mmcv.flowread('test.flo')
mmcv.flowshow(flow)
```

View File

@ -109,7 +109,7 @@ def bbox_scaling(bboxes, scale, clip_shape=None):
return scaled_bboxes
def imcrop(img, bboxes, scale_ratio=1.0, pad_fill=None):
def imcrop(img, bboxes, scale=1.0, pad_fill=None):
"""Crop image patches.
3 steps: scale the bboxes -> clip bboxes -> crop and pad.
@ -117,7 +117,7 @@ def imcrop(img, bboxes, scale_ratio=1.0, pad_fill=None):
Args:
img (ndarray): Image to be cropped.
bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
scale_ratio (float, optional): Scale ratio of bboxes, the default value
scale (float, optional): Scale ratio of bboxes, the default value
1.0 means no padding.
pad_fill (number or list): Value to be filled for padding, None for
no padding.
@ -132,7 +132,7 @@ def imcrop(img, bboxes, scale_ratio=1.0, pad_fill=None):
assert len(pad_fill) == chn
_bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes
scaled_bboxes = bbox_scaling(_bboxes, scale_ratio).astype(np.int32)
scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32)
clipped_bbox = bbox_clip(scaled_bboxes, img.shape)
patches = []

View File

@ -7,6 +7,7 @@ from .hooks import (Hook, CheckpointHook, ClosureHook, LrUpdaterHook,
from .checkpoint import (load_state_dict, load_checkpoint, weights_to_cpu,
save_checkpoint)
from .parallel import parallel_test, worker_func
from .priority import Priority, get_priority
from .utils import (get_host_info, get_dist_info, master_only, get_time_str,
obj_from_dict)
@ -15,6 +16,7 @@ __all__ = [
'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
'LoggerHook', 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook',
'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint',
'parallel_test', 'worker_func', 'get_host_info', 'get_dist_info',
'master_only', 'get_time_str', 'obj_from_dict'
'parallel_test', 'worker_func', 'Priority', 'get_priority',
'get_host_info', 'get_dist_info', 'master_only', 'get_time_str',
'obj_from_dict'
]

View File

@ -4,7 +4,14 @@ from ..hook import Hook
class LoggerHook(Hook):
"""Base class for logger hooks."""
"""Base class for logger hooks.
Args:
interval (int): Logging interval (every k iterations).
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.
reset_flag (bool): Whether to clear the output buffer after logging.
"""
__metaclass__ = ABCMeta

View File

@ -145,12 +145,11 @@ class PaviLoggerHook(LoggerHook):
instance_id=None,
config_file=None,
interval=10,
reset_meter=True,
ignore_last=True):
ignore_last=True,
reset_flag=True):
self.pavi = PaviClient(url, username, password, instance_id)
self.config_file = config_file
super(PaviLoggerHook, self).__init__(interval, reset_meter,
ignore_last)
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag)
def before_run(self, runner):
super(PaviLoggerHook, self).before_run(runner)

View File

@ -4,13 +4,10 @@ from ...utils import master_only
class TensorboardLoggerHook(LoggerHook):
def __init__(self,
log_dir,
interval=10,
reset_meter=True,
ignore_last=True):
super(TensorboardLoggerHook, self).__init__(interval, reset_meter,
ignore_last)
def __init__(self, log_dir, interval=10, ignore_last=True,
reset_flag=True):
super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
reset_flag)
self.log_dir = log_dir
@master_only

View File

@ -0,0 +1,53 @@
from enum import Enum


class Priority(Enum):
    """Hook priority levels.

    Lower value means higher priority (hooks are ordered by this value).

    +------------+------------+
    | Level      | Value      |
    +============+============+
    | HIGHEST    | 0          |
    +------------+------------+
    | VERY_HIGH  | 10         |
    +------------+------------+
    | HIGH       | 30         |
    +------------+------------+
    | NORMAL     | 50         |
    +------------+------------+
    | LOW        | 70         |
    +------------+------------+
    | VERY_LOW   | 90         |
    +------------+------------+
    | LOWEST     | 100        |
    +------------+------------+
    """

    HIGHEST = 0
    VERY_HIGH = 10
    HIGH = 30
    NORMAL = 50
    LOW = 70
    VERY_LOW = 90
    LOWEST = 100


def get_priority(priority):
    """Get priority value.

    Args:
        priority (int or str or :obj:`Priority`): Priority. Integers must lie
            in [0, 100]; strings are matched case-insensitively against the
            :obj:`Priority` level names.

    Returns:
        int: The priority value.

    Raises:
        ValueError: If an integer priority is outside [0, 100].
        KeyError: If a string does not name a :obj:`Priority` level.
        TypeError: If ``priority`` is of any other type.
    """
    if isinstance(priority, int):
        if priority < 0 or priority > 100:
            raise ValueError('priority must be between 0 and 100')
        return priority
    elif isinstance(priority, Priority):
        return priority.value
    elif isinstance(priority, str):
        # Level names are stored upper-case, so accept e.g. 'normal' too.
        return Priority[priority.upper()].value
    else:
        # Bug fix: the previous message omitted str, which is accepted above.
        raise TypeError(
            'priority must be an integer, string or Priority enum value')

View File

@ -10,11 +10,24 @@ from .log_buffer import LogBuffer
from .hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
OptimizerHook, lr_updater)
from .checkpoint import load_checkpoint, save_checkpoint
from .priority import get_priority
from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict
class Runner(object):
"""A training helper for PyTorch."""
"""A training helper for PyTorch.
Args:
model (:obj:`torch.nn.Module`): The model to be run.
batch_processor (callable): A callable method that process a data
batch. The interface of this method should be
`batch_processor(model, data, train_mode) -> dict`
optimizer (dict or :obj:`torch.optim.Optimizer`): If it is a dict,
runner will construct an optimizer according to it.
work_dir (str, optional): The working directory to save checkpoints
and logs.
log_level (int): Logging level.
"""
def __init__(self,
model,
@ -154,8 +167,8 @@ class Runner(object):
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=level)
logger = logging.getLogger(__name__)
if log_dir:
filename = '{}_{}.log'.format(get_time_str(), self.rank)
if log_dir and self.rank == 0:
filename = '{}.log'.format(get_time_str())
log_file = osp.join(log_dir, filename)
self._add_file_handler(logger, log_file, level=level)
return logger
@ -171,17 +184,18 @@ class Runner(object):
'lr is not applicable because optimizer does not exist.')
return [group['lr'] for group in self.optimizer.param_groups]
def register_hook(self, hook, priority=50):
def register_hook(self, hook, priority='NORMAL'):
"""Register a hook into the hook list.
Args:
hook (:obj:`Hook`): The hook to be registered.
priority (int): Hook priority. Lower value means higher priority.
priority (int or str or :obj:`Priority`): Hook priority.
Lower value means higher priority.
"""
assert isinstance(hook, Hook)
assert isinstance(priority, int) and priority >= 0 and priority <= 100
if hasattr(hook, 'priority'):
raise ValueError('"priority" is a reserved attribute for hooks')
priority = get_priority(priority)
hook.priority = priority
# insert the hook to a sorted list
inserted = False
@ -292,6 +306,17 @@ class Runner(object):
self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
def run(self, data_loaders, workflow, max_epochs, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
max_epochs (int): Total training epochs.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
@ -346,7 +371,7 @@ class Runner(object):
for info in log_config['hooks']:
logger_hook = obj_from_dict(
info, hooks, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority=60)
self.register_hook(logger_hook, priority='VERY_LOW')
def register_training_hooks(self,
lr_config,
@ -356,11 +381,12 @@ class Runner(object):
"""Register default hooks for training.
Default hooks include:
- LrUpdaterHook
- OptimizerStepperHook
- CheckpointSaverHook
- IterTimerHook
- LoggerHook
- LoggerHook(s)
"""
if optimizer_config is None:
optimizer_config = {}