mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Docs] Refine API reference. (#774)
* [Docs] Refine API reference * Add PoolFormer
This commit is contained in:
parent
29b882d4eb
commit
702c196514
2
.gitignore
vendored
2
.gitignore
vendored
@ -69,10 +69,12 @@ docs/en/_build/
|
|||||||
docs/en/_model_zoo.rst
|
docs/en/_model_zoo.rst
|
||||||
docs/en/modelzoo_statistics.md
|
docs/en/modelzoo_statistics.md
|
||||||
docs/en/papers/
|
docs/en/papers/
|
||||||
|
docs/en/api/generated/
|
||||||
docs/zh_CN/_build/
|
docs/zh_CN/_build/
|
||||||
docs/zh_CN/_model_zoo.rst
|
docs/zh_CN/_model_zoo.rst
|
||||||
docs/zh_CN/modelzoo_statistics.md
|
docs/zh_CN/modelzoo_statistics.md
|
||||||
docs/zh_CN/papers/
|
docs/zh_CN/papers/
|
||||||
|
docs/zh_CN/api/generated/
|
||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
target/
|
target/
|
||||||
|
@ -14,3 +14,7 @@ article.pytorch-article .section :not(dt) > code {
|
|||||||
background-color: #f3f4f7;
|
background-color: #f3f4f7;
|
||||||
border-radius: 5px;
|
border-radius: 5px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
table.colwidths-auto td {
|
||||||
|
width: 50%
|
||||||
|
}
|
||||||
|
14
docs/en/_templates/classtemplate.rst
Normal file
14
docs/en/_templates/classtemplate.rst
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
.. role:: hidden
|
||||||
|
:class: hidden-section
|
||||||
|
.. currentmodule:: {{ module }}
|
||||||
|
|
||||||
|
|
||||||
|
{{ name | underline}}
|
||||||
|
|
||||||
|
.. autoclass:: {{ name }}
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
..
|
||||||
|
autogenerated from source/_templates/classtemplate.rst
|
||||||
|
note it does not have :inherited-members:
|
@ -1,68 +0,0 @@
|
|||||||
mmcls.apis
|
|
||||||
-------------
|
|
||||||
.. automodule:: mmcls.apis
|
|
||||||
:members:
|
|
||||||
|
|
||||||
mmcls.core
|
|
||||||
-------------
|
|
||||||
|
|
||||||
evaluation
|
|
||||||
^^^^^^^^^^
|
|
||||||
.. automodule:: mmcls.core.evaluation
|
|
||||||
:members:
|
|
||||||
|
|
||||||
mmcls.models
|
|
||||||
---------------
|
|
||||||
|
|
||||||
models
|
|
||||||
^^^^^^
|
|
||||||
.. automodule:: mmcls.models
|
|
||||||
:members:
|
|
||||||
|
|
||||||
classifiers
|
|
||||||
^^^^^^^^^^^
|
|
||||||
.. automodule:: mmcls.models.classifiers
|
|
||||||
:members:
|
|
||||||
|
|
||||||
backbones
|
|
||||||
^^^^^^^^^^
|
|
||||||
.. automodule:: mmcls.models.backbones
|
|
||||||
:members:
|
|
||||||
|
|
||||||
necks
|
|
||||||
^^^^^^
|
|
||||||
.. automodule:: mmcls.models.necks
|
|
||||||
:members:
|
|
||||||
|
|
||||||
heads
|
|
||||||
^^^^^^
|
|
||||||
.. automodule:: mmcls.models.heads
|
|
||||||
:members:
|
|
||||||
|
|
||||||
losses
|
|
||||||
^^^^^^
|
|
||||||
.. automodule:: mmcls.models.losses
|
|
||||||
:members:
|
|
||||||
|
|
||||||
utils
|
|
||||||
^^^^^^
|
|
||||||
.. automodule:: mmcls.models.utils
|
|
||||||
:members:
|
|
||||||
|
|
||||||
mmcls.datasets
|
|
||||||
-----------------
|
|
||||||
|
|
||||||
datasets
|
|
||||||
^^^^^^^^
|
|
||||||
.. automodule:: mmcls.datasets
|
|
||||||
:members:
|
|
||||||
|
|
||||||
pipelines
|
|
||||||
^^^^^^^^^
|
|
||||||
.. automodule:: mmcls.datasets.pipelines
|
|
||||||
:members:
|
|
||||||
|
|
||||||
mmcls.utils
|
|
||||||
--------------
|
|
||||||
.. automodule:: mmcls.utils
|
|
||||||
:members:
|
|
45
docs/en/api/apis.rst
Normal file
45
docs/en/api/apis.rst
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
.. role:: hidden
|
||||||
|
:class: hidden-section
|
||||||
|
|
||||||
|
mmcls.apis
|
||||||
|
===================================
|
||||||
|
|
||||||
|
These are some high-level APIs for classification tasks.
|
||||||
|
|
||||||
|
.. contents:: mmcls.apis
|
||||||
|
:depth: 2
|
||||||
|
:local:
|
||||||
|
:backlinks: top
|
||||||
|
|
||||||
|
.. currentmodule:: mmcls.apis
|
||||||
|
|
||||||
|
Train
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
init_random_seed
|
||||||
|
set_random_seed
|
||||||
|
train_model
|
||||||
|
|
||||||
|
Test
|
||||||
|
------------------
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
single_gpu_test
|
||||||
|
multi_gpu_test
|
||||||
|
|
||||||
|
Inference
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
init_model
|
||||||
|
inference_model
|
||||||
|
show_result_pyplot
|
61
docs/en/api/core.rst
Normal file
61
docs/en/api/core.rst
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
.. role:: hidden
|
||||||
|
:class: hidden-section
|
||||||
|
|
||||||
|
mmcls.core
|
||||||
|
===================================
|
||||||
|
|
||||||
|
This package includes some runtime components. These components are useful in
|
||||||
|
classification tasks but not supported by MMCV yet.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Some components may be moved to MMCV in the future.
|
||||||
|
|
||||||
|
.. contents:: mmcls.core
|
||||||
|
:depth: 2
|
||||||
|
:local:
|
||||||
|
:backlinks: top
|
||||||
|
|
||||||
|
.. currentmodule:: mmcls.core
|
||||||
|
|
||||||
|
Evaluation
|
||||||
|
------------------
|
||||||
|
|
||||||
|
Evaluation metrics calculation functions
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
precision
|
||||||
|
recall
|
||||||
|
f1_score
|
||||||
|
precision_recall_f1
|
||||||
|
average_precision
|
||||||
|
mAP
|
||||||
|
support
|
||||||
|
average_performance
|
||||||
|
calculate_confusion_matrix
|
||||||
|
|
||||||
|
Hook
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
ClassNumCheckHook
|
||||||
|
PreciseBNHook
|
||||||
|
CosineAnnealingCooldownLrUpdaterHook
|
||||||
|
|
||||||
|
|
||||||
|
Optimizers
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
Lamb
|
56
docs/en/api/datasets.rst
Normal file
56
docs/en/api/datasets.rst
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
.. role:: hidden
|
||||||
|
:class: hidden-section
|
||||||
|
|
||||||
|
mmcls.datasets
|
||||||
|
===================================
|
||||||
|
|
||||||
|
The ``datasets`` package contains several commonly used datasets for image classification tasks and some dataset wrappers.
|
||||||
|
|
||||||
|
.. currentmodule:: mmcls.datasets
|
||||||
|
|
||||||
|
Custom Dataset
|
||||||
|
--------------
|
||||||
|
|
||||||
|
.. autoclass:: CustomDataset
|
||||||
|
|
||||||
|
ImageNet
|
||||||
|
--------
|
||||||
|
|
||||||
|
.. autoclass:: ImageNet
|
||||||
|
|
||||||
|
.. autoclass:: ImageNet21k
|
||||||
|
|
||||||
|
CIFAR
|
||||||
|
-----
|
||||||
|
|
||||||
|
.. autoclass:: CIFAR10
|
||||||
|
|
||||||
|
.. autoclass:: CIFAR100
|
||||||
|
|
||||||
|
MNIST
|
||||||
|
-----
|
||||||
|
|
||||||
|
.. autoclass:: MNIST
|
||||||
|
|
||||||
|
.. autoclass:: FashionMNIST
|
||||||
|
|
||||||
|
VOC
|
||||||
|
---
|
||||||
|
|
||||||
|
.. autoclass:: VOC
|
||||||
|
|
||||||
|
Base classes
|
||||||
|
------------
|
||||||
|
|
||||||
|
.. autoclass:: BaseDataset
|
||||||
|
|
||||||
|
.. autoclass:: MultiLabelDataset
|
||||||
|
|
||||||
|
Dataset Wrappers
|
||||||
|
----------------
|
||||||
|
|
||||||
|
.. autoclass:: ConcatDataset
|
||||||
|
|
||||||
|
.. autoclass:: RepeatDataset
|
||||||
|
|
||||||
|
.. autoclass:: ClassBalancedDataset
|
137
docs/en/api/models.rst
Normal file
137
docs/en/api/models.rst
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
.. role:: hidden
|
||||||
|
:class: hidden-section
|
||||||
|
|
||||||
|
mmcls.models
|
||||||
|
===================================
|
||||||
|
|
||||||
|
The ``models`` package contains several sub-packages for addressing the different components of a model.
|
||||||
|
|
||||||
|
- :ref:`classifiers`: The top-level module which defines the whole process of a classification model.
|
||||||
|
- :ref:`backbones`: Usually a feature extraction network, e.g., ResNet, MobileNet.
|
||||||
|
- :ref:`necks`: The component between backbones and heads, e.g., GlobalAveragePooling.
|
||||||
|
- :ref:`heads`: The component for specific tasks. In MMClassification, we provide heads for classification.
|
||||||
|
- :ref:`losses`: Loss functions.
|
||||||
|
|
||||||
|
.. currentmodule:: mmcls.models
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
build_classifier
|
||||||
|
build_backbone
|
||||||
|
build_neck
|
||||||
|
build_head
|
||||||
|
build_loss
|
||||||
|
|
||||||
|
.. _classifiers:
|
||||||
|
|
||||||
|
Classifier
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
BaseClassifier
|
||||||
|
ImageClassifier
|
||||||
|
|
||||||
|
.. _backbones:
|
||||||
|
|
||||||
|
Backbones
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
AlexNet
|
||||||
|
CSPDarkNet
|
||||||
|
CSPNet
|
||||||
|
CSPResNeXt
|
||||||
|
CSPResNet
|
||||||
|
Conformer
|
||||||
|
ConvMixer
|
||||||
|
ConvNeXt
|
||||||
|
DistilledVisionTransformer
|
||||||
|
EfficientNet
|
||||||
|
HRNet
|
||||||
|
LeNet5
|
||||||
|
MlpMixer
|
||||||
|
MobileNetV2
|
||||||
|
MobileNetV3
|
||||||
|
PCPVT
|
||||||
|
PoolFormer
|
||||||
|
RegNet
|
||||||
|
RepMLPNet
|
||||||
|
RepVGG
|
||||||
|
Res2Net
|
||||||
|
ResNeSt
|
||||||
|
ResNeXt
|
||||||
|
ResNet
|
||||||
|
ResNetV1c
|
||||||
|
ResNetV1d
|
||||||
|
ResNet_CIFAR
|
||||||
|
SEResNeXt
|
||||||
|
SEResNet
|
||||||
|
SVT
|
||||||
|
ShuffleNetV1
|
||||||
|
ShuffleNetV2
|
||||||
|
SwinTransformer
|
||||||
|
T2T_ViT
|
||||||
|
TIMMBackbone
|
||||||
|
TNT
|
||||||
|
VGG
|
||||||
|
VisionTransformer
|
||||||
|
|
||||||
|
.. _necks:
|
||||||
|
|
||||||
|
Necks
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
GlobalAveragePooling
|
||||||
|
GeneralizedMeanPooling
|
||||||
|
HRFuseScales
|
||||||
|
|
||||||
|
.. _heads:
|
||||||
|
|
||||||
|
Heads
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
ClsHead
|
||||||
|
LinearClsHead
|
||||||
|
StackedLinearClsHead
|
||||||
|
MultiLabelClsHead
|
||||||
|
MultiLabelLinearClsHead
|
||||||
|
VisionTransformerClsHead
|
||||||
|
DeiTClsHead
|
||||||
|
ConformerHead
|
||||||
|
|
||||||
|
.. _losses:
|
||||||
|
|
||||||
|
Losses
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
Accuracy
|
||||||
|
AsymmetricLoss
|
||||||
|
CrossEntropyLoss
|
||||||
|
LabelSmoothLoss
|
||||||
|
FocalLoss
|
||||||
|
SeesawLoss
|
35
docs/en/api/models.utils.augment.rst
Normal file
35
docs/en/api/models.utils.augment.rst
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
.. role:: hidden
|
||||||
|
:class: hidden-section
|
||||||
|
|
||||||
|
Batch Augmentation
|
||||||
|
===================================
|
||||||
|
|
||||||
|
Batch augmentation is the augmentation which involves multiple samples, such as Mixup and CutMix.
|
||||||
|
|
||||||
|
In MMClassification, these batch augmentations are used as a part of :ref:`classifiers`. A typical usage is as below:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
model = dict(
|
||||||
|
backbone = ...,
|
||||||
|
neck = ...,
|
||||||
|
head = ...,
|
||||||
|
train_cfg=dict(augments=[
|
||||||
|
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
|
||||||
|
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
|
||||||
|
]))
|
||||||
|
)
|
||||||
|
|
||||||
|
.. currentmodule:: mmcls.models.utils.augment
|
||||||
|
|
||||||
|
Mixup
|
||||||
|
-----
|
||||||
|
.. autoclass:: BatchMixupLayer
|
||||||
|
|
||||||
|
CutMix
|
||||||
|
------
|
||||||
|
.. autoclass:: BatchCutMixLayer
|
||||||
|
|
||||||
|
ResizeMix
|
||||||
|
---------
|
||||||
|
.. autoclass:: BatchResizeMixLayer
|
50
docs/en/api/models.utils.rst
Normal file
50
docs/en/api/models.utils.rst
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
.. role:: hidden
|
||||||
|
:class: hidden-section
|
||||||
|
|
||||||
|
mmcls.models.utils
|
||||||
|
===================================
|
||||||
|
|
||||||
|
This package includes some helper functions and common components used in various networks.
|
||||||
|
|
||||||
|
.. contents:: mmcls.models.utils
|
||||||
|
:depth: 2
|
||||||
|
:local:
|
||||||
|
:backlinks: top
|
||||||
|
|
||||||
|
.. currentmodule:: mmcls.models.utils
|
||||||
|
|
||||||
|
Common Components
|
||||||
|
------------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
InvertedResidual
|
||||||
|
SELayer
|
||||||
|
ShiftWindowMSA
|
||||||
|
MultiheadAttention
|
||||||
|
ConditionalPositionEncoding
|
||||||
|
|
||||||
|
Helper Functions
|
||||||
|
------------------
|
||||||
|
|
||||||
|
channel_shuffle
|
||||||
|
^^^^^^^^^^^^^^^
|
||||||
|
.. autofunction:: channel_shuffle
|
||||||
|
|
||||||
|
make_divisible
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
.. autofunction:: make_divisible
|
||||||
|
|
||||||
|
to_ntuple
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
.. autofunction:: to_ntuple
|
||||||
|
.. autofunction:: to_2tuple
|
||||||
|
.. autofunction:: to_3tuple
|
||||||
|
.. autofunction:: to_4tuple
|
||||||
|
|
||||||
|
is_tracing
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
.. autofunction:: is_tracing
|
171
docs/en/api/transforms.rst
Normal file
171
docs/en/api/transforms.rst
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
.. role:: hidden
|
||||||
|
:class: hidden-section
|
||||||
|
|
||||||
|
Data Transformations
|
||||||
|
***********************************
|
||||||
|
|
||||||
|
In MMClassification, the data preparation and the dataset are decoupled. The
|
||||||
|
datasets only define how to get samples' basic information from the file
|
||||||
|
system. This basic information includes the ground-truth label and raw image
|
||||||
|
data / the paths of images.
|
||||||
|
|
||||||
|
To prepare the inputs data, we need to do some transformations on these basic
|
||||||
|
information. These transformations include loading, preprocessing and
|
||||||
|
formatting. And a series of data transformations makes up a data pipeline.
|
||||||
|
Therefore, you can find a ``pipeline`` argument in the configs of datasets,
|
||||||
|
for example:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
img_norm_cfg = dict(
|
||||||
|
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='RandomResizedCrop', size=224),
|
||||||
|
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||||
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
dict(type='ToTensor', keys=['gt_label']),
|
||||||
|
dict(type='Collect', keys=['img', 'gt_label'])
|
||||||
|
]
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='Resize', size=256),
|
||||||
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
dict(type='Collect', keys=['img'])
|
||||||
|
]
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
train=dict(..., pipeline=train_pipeline),
|
||||||
|
val=dict(..., pipeline=test_pipeline),
|
||||||
|
test=dict(..., pipeline=test_pipeline),
|
||||||
|
)
|
||||||
|
|
||||||
|
Every item of a pipeline list is one of the following data transformation classes. And if you want to add a custom data transformation class, the tutorial :doc:`Custom Data Pipelines </tutorials/data_pipeline>` will help you.
|
||||||
|
|
||||||
|
.. contents:: mmcls.datasets.pipelines
|
||||||
|
:depth: 2
|
||||||
|
:local:
|
||||||
|
:backlinks: top
|
||||||
|
|
||||||
|
.. currentmodule:: mmcls.datasets.pipelines
|
||||||
|
|
||||||
|
Loading
|
||||||
|
=======
|
||||||
|
|
||||||
|
LoadImageFromFile
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: LoadImageFromFile
|
||||||
|
|
||||||
|
Preprocessing and Augmentation
|
||||||
|
==============================
|
||||||
|
|
||||||
|
CenterCrop
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: CenterCrop
|
||||||
|
|
||||||
|
Lighting
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: Lighting
|
||||||
|
|
||||||
|
Normalize
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: Normalize
|
||||||
|
|
||||||
|
Pad
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: Pad
|
||||||
|
|
||||||
|
Resize
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: Resize
|
||||||
|
|
||||||
|
RandomCrop
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: RandomCrop
|
||||||
|
|
||||||
|
RandomErasing
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: RandomErasing
|
||||||
|
|
||||||
|
RandomFlip
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: RandomFlip
|
||||||
|
|
||||||
|
RandomGrayscale
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: RandomGrayscale
|
||||||
|
|
||||||
|
RandomResizedCrop
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: RandomResizedCrop
|
||||||
|
|
||||||
|
ColorJitter
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: ColorJitter
|
||||||
|
|
||||||
|
|
||||||
|
Composed Augmentation
|
||||||
|
---------------------
|
||||||
|
Composed augmentation is a kind of method which composes a series of data
|
||||||
|
augmentation transformations, such as ``AutoAugment`` and ``RandAugment``.
|
||||||
|
|
||||||
|
.. autoclass:: AutoAugment
|
||||||
|
|
||||||
|
.. autoclass:: RandAugment
|
||||||
|
|
||||||
|
In composed augmentation, we need to specify several data transformations or
|
||||||
|
several groups of data transformations (The ``policies`` argument) as the
|
||||||
|
random sampling space. These data transformations are chosen from the below
|
||||||
|
table. In addition, we provide some preset policies in `this folder`_.
|
||||||
|
|
||||||
|
.. _this folder: https://github.com/open-mmlab/mmclassification/tree/master/configs/_base_/datasets/pipelines
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
AutoContrast
|
||||||
|
Brightness
|
||||||
|
ColorTransform
|
||||||
|
Contrast
|
||||||
|
Cutout
|
||||||
|
Equalize
|
||||||
|
Invert
|
||||||
|
Posterize
|
||||||
|
Rotate
|
||||||
|
Sharpness
|
||||||
|
Shear
|
||||||
|
Solarize
|
||||||
|
SolarizeAdd
|
||||||
|
Translate
|
||||||
|
|
||||||
|
Formatting
|
||||||
|
==========
|
||||||
|
|
||||||
|
Collect
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: Collect
|
||||||
|
|
||||||
|
ImageToTensor
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: ImageToTensor
|
||||||
|
|
||||||
|
ToNumpy
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: ToNumpy
|
||||||
|
|
||||||
|
ToPIL
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: ToPIL
|
||||||
|
|
||||||
|
ToTensor
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: ToTensor
|
||||||
|
|
||||||
|
Transpose
|
||||||
|
---------------------
|
||||||
|
.. autoclass:: Transpose
|
23
docs/en/api/utils.rst
Normal file
23
docs/en/api/utils.rst
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
.. role:: hidden
|
||||||
|
:class: hidden-section
|
||||||
|
|
||||||
|
mmcls.utils
|
||||||
|
===================================
|
||||||
|
|
||||||
|
These are some useful helper functions in the ``utils`` package.
|
||||||
|
|
||||||
|
.. contents:: mmcls.utils
|
||||||
|
:depth: 1
|
||||||
|
:local:
|
||||||
|
:backlinks: top
|
||||||
|
|
||||||
|
.. currentmodule:: mmcls.utils
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
collect_env
|
||||||
|
get_root_logger
|
||||||
|
load_json_log
|
||||||
|
setup_multi_processes
|
@ -44,6 +44,8 @@ release = get_version()
|
|||||||
# ones.
|
# ones.
|
||||||
extensions = [
|
extensions = [
|
||||||
'sphinx.ext.autodoc',
|
'sphinx.ext.autodoc',
|
||||||
|
'sphinx.ext.autosummary',
|
||||||
|
'sphinx.ext.intersphinx',
|
||||||
'sphinx.ext.napoleon',
|
'sphinx.ext.napoleon',
|
||||||
'sphinx.ext.viewcode',
|
'sphinx.ext.viewcode',
|
||||||
'sphinx_markdown_tables',
|
'sphinx_markdown_tables',
|
||||||
@ -218,6 +220,15 @@ StandaloneHTMLBuilder.supported_image_types = [
|
|||||||
# Ignore >>> when copying code
|
# Ignore >>> when copying code
|
||||||
copybutton_prompt_text = r'>>> |\.\.\. '
|
copybutton_prompt_text = r'>>> |\.\.\. '
|
||||||
copybutton_prompt_is_regexp = True
|
copybutton_prompt_is_regexp = True
|
||||||
|
# Auto-generated header anchors
|
||||||
|
myst_heading_anchors = 3
|
||||||
|
# Configuration for intersphinx
|
||||||
|
intersphinx_mapping = {
|
||||||
|
'python': ('https://docs.python.org/3', None),
|
||||||
|
'numpy': ('https://numpy.org/doc/stable', None),
|
||||||
|
'torch': ('https://pytorch.org/docs/stable/', None),
|
||||||
|
'mmcv': ('https://mmcv.readthedocs.io/en/master/', None),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def builder_inited_handler(app):
|
def builder_inited_handler(app):
|
||||||
|
2
docs/en/docutils.conf
Normal file
2
docs/en/docutils.conf
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
[html writers]
|
||||||
|
table_style: colwidths-auto
|
@ -57,9 +57,17 @@ You can switch between Chinese and English documentation in the lower-left corne
|
|||||||
|
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
|
:maxdepth: 1
|
||||||
:caption: API Reference
|
:caption: API Reference
|
||||||
|
|
||||||
api.rst
|
mmcls.apis <api/apis>
|
||||||
|
mmcls.core <api/core>
|
||||||
|
mmcls.models <api/models>
|
||||||
|
mmcls.models.utils <api/models.utils>
|
||||||
|
mmcls.datasets <api/datasets>
|
||||||
|
Data Transformations <api/transforms>
|
||||||
|
Batch Augmentation <api/models.utils.augment>
|
||||||
|
mmcls.utils <api/utils>
|
||||||
|
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
|
@ -13,7 +13,7 @@ If you wish to inspect the config file, you may run `python tools/misc/print_con
|
|||||||
- [Ignore some fields in the base configs](#ignore-some-fields-in-the-base-configs)
|
- [Ignore some fields in the base configs](#ignore-some-fields-in-the-base-configs)
|
||||||
- [Use some fields in the base configs](#use-some-fields-in-the-base-configs)
|
- [Use some fields in the base configs](#use-some-fields-in-the-base-configs)
|
||||||
- [Modify config through script arguments](#modify-config-through-script-arguments)
|
- [Modify config through script arguments](#modify-config-through-script-arguments)
|
||||||
- [Import user-defined modules](#import-ser-defined-modules)
|
- [Import user-defined modules](#import-user-defined-modules)
|
||||||
- [FAQ](#faq)
|
- [FAQ](#faq)
|
||||||
|
|
||||||
<!-- TOC -->
|
<!-- TOC -->
|
||||||
|
@ -98,6 +98,8 @@ More supported backends can be found in [mmcv.fileio.FileClient](https://github.
|
|||||||
|
|
||||||
- remove: all other keys except for those specified by `keys`
|
- remove: all other keys except for those specified by `keys`
|
||||||
|
|
||||||
|
For more information about other data transformation classes, please refer to [Data Transformations](../api/transforms.rst)
|
||||||
|
|
||||||
## Extend and use custom pipelines
|
## Extend and use custom pipelines
|
||||||
|
|
||||||
1. Write a new pipeline in any file, e.g., `my_pipeline.py`, and place it in
|
1. Write a new pipeline in any file, e.g., `my_pipeline.py`, and place it in
|
||||||
|
@ -7,14 +7,8 @@ In this tutorial, we will introduce some methods about how to customize workflow
|
|||||||
- [Customize Workflow](#customize-workflow)
|
- [Customize Workflow](#customize-workflow)
|
||||||
- [Hooks](#hooks)
|
- [Hooks](#hooks)
|
||||||
- [Default training hooks](#default-training-hooks)
|
- [Default training hooks](#default-training-hooks)
|
||||||
- [CheckpointHook](#checkpointhook)
|
|
||||||
- [LoggerHooks](#loggerhooks)
|
|
||||||
- [EvalHook](#evalhook)
|
|
||||||
- [Use other implemented hooks](#use-other-implemented-hooks)
|
- [Use other implemented hooks](#use-other-implemented-hooks)
|
||||||
- [Customize self-implemented hooks](#customize-self-implemented-hooks)
|
- [Customize self-implemented hooks](#customize-self-implemented-hooks)
|
||||||
- [1. Implement a new hook](#1.-implement-a-new-hook)
|
|
||||||
- [2. Register the new hook](#2.-register-the-new-hook)
|
|
||||||
- [3. Modify the config](#3.-modify-the-config)
|
|
||||||
- [FAQ](#faq)
|
- [FAQ](#faq)
|
||||||
|
|
||||||
<!-- TOC -->
|
<!-- TOC -->
|
||||||
|
@ -18,6 +18,21 @@ def single_gpu_test(model,
|
|||||||
show=False,
|
show=False,
|
||||||
out_dir=None,
|
out_dir=None,
|
||||||
**show_kwargs):
|
**show_kwargs):
|
||||||
|
"""Test model with local single gpu.
|
||||||
|
|
||||||
|
This method tests model with a single gpu and supports showing results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (:obj:`torch.nn.Module`): Model to be tested.
|
||||||
|
data_loader (:obj:`torch.utils.data.DataLoader`): Pytorch data loader.
|
||||||
|
show (bool): Whether to show the test results. Defaults to False.
|
||||||
|
out_dir (str): The output directory of result plots of all samples.
|
||||||
|
Defaults to None, which means not to write output files.
|
||||||
|
**show_kwargs: Any other keyword arguments for showing results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: The prediction results.
|
||||||
|
"""
|
||||||
model.eval()
|
model.eval()
|
||||||
results = []
|
results = []
|
||||||
dataset = data_loader.dataset
|
dataset = data_loader.dataset
|
||||||
|
@ -75,6 +75,28 @@ def train_model(model,
|
|||||||
timestamp=None,
|
timestamp=None,
|
||||||
device=None,
|
device=None,
|
||||||
meta=None):
|
meta=None):
|
||||||
|
"""Train a model.
|
||||||
|
|
||||||
|
This method will build dataloaders, wrap the model and build a runner
|
||||||
|
according to the provided config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (:obj:`torch.nn.Module`): The model to be run.
|
||||||
|
dataset (:obj:`mmcls.datasets.BaseDataset` | List[BaseDataset]):
|
||||||
|
The dataset used to train the model. It can be a single dataset,
|
||||||
|
or a list of datasets with the same length as workflow.
|
||||||
|
cfg (:obj:`mmcv.utils.Config`): The configs of the experiment.
|
||||||
|
distributed (bool): Whether to train the model in a distributed
|
||||||
|
environment. Defaults to False.
|
||||||
|
validate (bool): Whether to do validation with
|
||||||
|
:obj:`mmcv.runner.EvalHook`. Defaults to False.
|
||||||
|
timestamp (str, optional): The timestamp string to auto generate the
|
||||||
|
name of log files. Defaults to None.
|
||||||
|
device (str, optional): TODO
|
||||||
|
meta (dict, optional): A dict recording some important information such as
|
||||||
|
environment info and seed, which will be logged in logger hook.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
logger = get_root_logger()
|
logger = get_root_logger()
|
||||||
|
|
||||||
# prepare data loaders
|
# prepare data loaders
|
||||||
|
@ -65,13 +65,15 @@ from torch.optim import Optimizer
|
|||||||
|
|
||||||
@OPTIMIZERS.register_module()
|
@OPTIMIZERS.register_module()
|
||||||
class Lamb(Optimizer):
|
class Lamb(Optimizer):
|
||||||
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer
|
"""A pure pytorch variant of FuseLAMB (NvLamb variant) optimizer.
|
||||||
from apex.optimizers.FusedLAMB
|
|
||||||
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/
|
|
||||||
PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
|
|
||||||
|
|
||||||
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training
|
This class is copied from `timm`_. The LAMB was proposed in `Large Batch
|
||||||
BERT in 76 minutes`_.
|
Optimization for Deep Learning - Training BERT in 76 minutes`_.
|
||||||
|
|
||||||
|
.. _timm:
|
||||||
|
https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py
|
||||||
|
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
|
||||||
|
https://arxiv.org/abs/1904.00962
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
params (iterable): iterable of parameters to optimize or dicts defining
|
params (iterable): iterable of parameters to optimize or dicts defining
|
||||||
@ -89,13 +91,7 @@ class Lamb(Optimizer):
|
|||||||
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
|
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
|
||||||
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
|
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
|
||||||
weight decay parameter (default: False)
|
weight decay parameter (default: False)
|
||||||
|
""" # noqa: E501
|
||||||
.. _Large Batch Optimization for Deep Learning - Training BERT in 76
|
|
||||||
minutes:
|
|
||||||
https://arxiv.org/abs/1904.00962
|
|
||||||
.. _On the Convergence of Adam and Beyond:
|
|
||||||
https://openreview.net/forum?id=ryQu7f-RZ
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
params,
|
params,
|
||||||
|
@ -18,7 +18,7 @@ class ConcatDataset(_ConcatDataset):
|
|||||||
add `get_cat_ids` function.
|
add `get_cat_ids` function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
datasets (list[:obj:`Dataset`]): A list of datasets.
|
datasets (list[:obj:`BaseDataset`]): A list of datasets.
|
||||||
separate_eval (bool): Whether to evaluate the results
|
separate_eval (bool): Whether to evaluate the results
|
||||||
separately if it is used as validation dataset.
|
separately if it is used as validation dataset.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
@ -117,7 +117,7 @@ class RepeatDataset(object):
|
|||||||
epochs.
|
epochs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset (:obj:`Dataset`): The dataset to be repeated.
|
dataset (:obj:`BaseDataset`): The dataset to be repeated.
|
||||||
times (int): Repeat times.
|
times (int): Repeat times.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -157,9 +157,11 @@ class RepeatDataset(object):
|
|||||||
class ClassBalancedDataset(object):
|
class ClassBalancedDataset(object):
|
||||||
r"""A wrapper of repeated dataset with repeat factor.
|
r"""A wrapper of repeated dataset with repeat factor.
|
||||||
|
|
||||||
Suitable for training on class imbalanced datasets like LVIS. Following
|
Suitable for training on class imbalanced datasets like LVIS. Following the
|
||||||
the sampling strategy in [#1]_, in each epoch, an image may appear multiple
|
sampling strategy in `this paper`_, in each epoch, an image may appear
|
||||||
times based on its "repeat factor".
|
multiple times based on its "repeat factor".
|
||||||
|
|
||||||
|
.. _this paper: https://arxiv.org/pdf/1908.03195.pdf
|
||||||
|
|
||||||
The repeat factor for an image is a function of the frequency the rarest
|
The repeat factor for an image is a function of the frequency the rarest
|
||||||
category labeled in that image. The "frequency of category c" in [0, 1]
|
category labeled in that image. The "frequency of category c" in [0, 1]
|
||||||
@ -184,16 +186,13 @@ class ClassBalancedDataset(object):
|
|||||||
.. math::
|
.. math::
|
||||||
r(I) = \max_{c \in L(I)} r(c)
|
r(I) = \max_{c \in L(I)} r(c)
|
||||||
|
|
||||||
References:
|
|
||||||
.. [#1] https://arxiv.org/pdf/1908.03195.pdf
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset (:obj:`CustomDataset`): The dataset to be repeated.
|
dataset (:obj:`BaseDataset`): The dataset to be repeated.
|
||||||
oversample_thr (float): frequency threshold below which data is
|
oversample_thr (float): frequency threshold below which data is
|
||||||
repeated. For categories with `f_c` >= `oversample_thr`, there is
|
repeated. For categories with ``f_c`` >= ``oversample_thr``, there
|
||||||
no oversampling. For categories with `f_c` < `oversample_thr`, the
|
is no oversampling. For categories with ``f_c`` <
|
||||||
degree of oversampling following the square-root inverse frequency
|
``oversample_thr``, the degree of oversampling following the
|
||||||
heuristic above.
|
square-root inverse frequency heuristic above.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset, oversample_thr):
|
def __init__(self, dataset, oversample_thr):
|
||||||
@ -278,7 +277,7 @@ class KFoldDataset:
|
|||||||
and use the fold left to do validation.
|
and use the fold left to do validation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset (:obj:`CustomDataset`): The dataset to be divided.
|
dataset (:obj:`BaseDataset`): The dataset to be divided.
|
||||||
fold (int): The fold used to do validation. Defaults to 0.
|
fold (int): The fold used to do validation. Defaults to 0.
|
||||||
num_splits (int): The number of all folds. Defaults to 5.
|
num_splits (int): The number of all folds. Defaults to 5.
|
||||||
test_mode (bool): Use the training dataset or validation dataset.
|
test_mode (bool): Use the training dataset or validation dataset.
|
||||||
|
@ -9,8 +9,33 @@ from .custom import CustomDataset
|
|||||||
class ImageNet(CustomDataset):
|
class ImageNet(CustomDataset):
|
||||||
"""`ImageNet <http://www.image-net.org>`_ Dataset.
|
"""`ImageNet <http://www.image-net.org>`_ Dataset.
|
||||||
|
|
||||||
This implementation is modified from
|
The dataset supports two kinds of annotation format. More details can be
|
||||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py
|
found in :class:`CustomDataset`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_prefix (str): The path of data directory.
|
||||||
|
pipeline (Sequence[dict]): A list of dict, where each element
|
||||||
|
represents an operation defined in :mod:`mmcls.datasets.pipelines`.
|
||||||
|
Defaults to an empty tuple.
|
||||||
|
classes (str | Sequence[str], optional): Specify names of classes.
|
||||||
|
|
||||||
|
- If is string, it should be a file path, and every line of
|
||||||
|
the file is a name of a class.
|
||||||
|
- If is a sequence of string, every item is a name of class.
|
||||||
|
- If is None, use the default ImageNet-1k classes names.
|
||||||
|
|
||||||
|
Defaults to None.
|
||||||
|
ann_file (str, optional): The annotation file. If is string, read
|
||||||
|
samples paths from the ann_file. If is None, find samples in
|
||||||
|
``data_prefix``. Defaults to None.
|
||||||
|
extensions (Sequence[str]): A sequence of allowed extensions. Defaults
|
||||||
|
to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
|
||||||
|
test_mode (bool): In train mode or test mode. It's only a mark and
|
||||||
|
won't be used in this class. Defaults to False.
|
||||||
|
file_client_args (dict, optional): Arguments to instantiate a
|
||||||
|
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||||
|
If None, automatically inferred from the specified path.
|
||||||
|
Defaults to None.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
|
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
|
||||||
|
@ -31,8 +31,8 @@ class ImageNet21k(CustomDataset):
|
|||||||
- If is string, it should be a file path, and every line of
|
- If is string, it should be a file path, and every line of
|
||||||
the file is a name of a class.
|
the file is a name of a class.
|
||||||
- If is a sequence of string, every item is a name of class.
|
- If is a sequence of string, every item is a name of class.
|
||||||
- If is None, use ``cls.CLASSES`` or the names of sub folders
|
- If is None, the object won't have category information.
|
||||||
(If use the second way to arrange samples).
|
(Not recommended)
|
||||||
|
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
ann_file (str, optional): The annotation file. If is string, read
|
ann_file (str, optional): The annotation file. If is string, read
|
||||||
|
@ -133,8 +133,8 @@ class RandAugment(object):
|
|||||||
When magnitude_std=0, we calculate the magnitude as follows:
|
When magnitude_std=0, we calculate the magnitude as follows:
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\text{magnitude} = \frac{\text{magnitude\_level}}
|
\text{magnitude} = \frac{\text{magnitude_level}}
|
||||||
{\text{total\_level}} \times (\text{val2} - \text{val1})
|
{\text{total_level}} \times (\text{val2} - \text{val1})
|
||||||
+ \text{val1}
|
+ \text{val1}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -798,8 +798,8 @@ class CenterCrop(object):
|
|||||||
to perform the center crop with the ``crop_size_`` as:
|
to perform the center crop with the ``crop_size_`` as:
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\text{crop\_size\_} = \frac{\text{crop\_size}}{\text{crop\_size} +
|
\text{crop_size_} = \frac{\text{crop_size}}{\text{crop_size} +
|
||||||
\text{crop\_padding}} \times \text{short\_edge}
|
\text{crop_padding}} \times \text{short_edge}
|
||||||
|
|
||||||
And then the pipeline resizes the img to the input crop size.
|
And then the pipeline resizes the img to the input crop size.
|
||||||
"""
|
"""
|
||||||
|
@ -201,8 +201,9 @@ class PixelEmbed(BaseModule):
|
|||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
class TNT(BaseBackbone):
|
class TNT(BaseBackbone):
|
||||||
""" Transformer in Transformer
|
"""Transformer in Transformer.
|
||||||
A PyTorch implement of : `Transformer in Transformer
|
|
||||||
|
A PyTorch implementation of: `Transformer in Transformer
|
||||||
<https://arxiv.org/abs/2103.00112>`_
|
<https://arxiv.org/abs/2103.00112>`_
|
||||||
|
|
||||||
Inspiration from
|
Inspiration from
|
||||||
|
@ -9,6 +9,24 @@ from .vision_transformer_head import VisionTransformerClsHead
|
|||||||
|
|
||||||
@HEADS.register_module()
|
@HEADS.register_module()
|
||||||
class DeiTClsHead(VisionTransformerClsHead):
|
class DeiTClsHead(VisionTransformerClsHead):
|
||||||
|
"""Distilled Vision Transformer classifier head.
|
||||||
|
|
||||||
|
Comparing with the :class:`VisionTransformerClsHead`, this head adds an
|
||||||
|
extra linear layer to handle the dist token. The final classification score
|
||||||
|
is the average of both linear transformation results of ``cls_token`` and
|
||||||
|
``dist_token``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_classes (int): Number of categories excluding the background
|
||||||
|
category.
|
||||||
|
in_channels (int): Number of channels in the input feature map.
|
||||||
|
hidden_dim (int): Number of the dimensions for hidden layer.
|
||||||
|
Defaults to None, which means no extra hidden layer.
|
||||||
|
act_cfg (dict): The activation config. Only available during
|
||||||
|
pre-training. Defaults to ``dict(type='Tanh')``.
|
||||||
|
init_cfg (dict): The extra initialization configs. Defaults to
|
||||||
|
``dict(type='Constant', layer='Linear', val=0)``.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(DeiTClsHead, self).__init__(*args, **kwargs)
|
super(DeiTClsHead, self).__init__(*args, **kwargs)
|
||||||
|
@ -20,10 +20,12 @@ class VisionTransformerClsHead(ClsHead):
|
|||||||
num_classes (int): Number of categories excluding the background
|
num_classes (int): Number of categories excluding the background
|
||||||
category.
|
category.
|
||||||
in_channels (int): Number of channels in the input feature map.
|
in_channels (int): Number of channels in the input feature map.
|
||||||
hidden_dim (int): Number of the dimensions for hidden layer. Only
|
hidden_dim (int): Number of the dimensions for hidden layer.
|
||||||
available during pre-training. Default None.
|
Defaults to None, which means no extra hidden layer.
|
||||||
act_cfg (dict): The activation config. Only available during
|
act_cfg (dict): The activation config. Only available during
|
||||||
pre-training. Defaults to Tanh.
|
pre-training. Defaults to ``dict(type='Tanh')``.
|
||||||
|
init_cfg (dict): The extra initialization configs. Defaults to
|
||||||
|
``dict(type='Constant', layer='Linear', val=0)``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -117,7 +117,42 @@ class BaseCutMixLayer(object, metaclass=ABCMeta):
|
|||||||
|
|
||||||
@AUGMENT.register_module(name='BatchCutMix')
|
@AUGMENT.register_module(name='BatchCutMix')
|
||||||
class BatchCutMixLayer(BaseCutMixLayer):
|
class BatchCutMixLayer(BaseCutMixLayer):
|
||||||
"""CutMix layer for batch CutMix."""
|
r"""CutMix layer for a batch of data.
|
||||||
|
|
||||||
|
CutMix is a method to improve the network's generalization capability. It's
|
||||||
|
proposed in `CutMix: Regularization Strategy to Train Strong Classifiers
|
||||||
|
with Localizable Features <https://arxiv.org/abs/1905.04899>`_
|
||||||
|
|
||||||
|
With this method, patches are cut and pasted among training images where
|
||||||
|
the ground truth labels are also mixed proportionally to the area of the
|
||||||
|
patches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alpha (float): Parameters for Beta distribution to generate the
|
||||||
|
mixing ratio. It should be a positive number. More details
|
||||||
|
can be found in :class:`BatchMixupLayer`.
|
||||||
|
num_classes (int): The number of classes.
|
||||||
|
prob (float): The probability to execute cutmix. It should be in
|
||||||
|
range [0, 1]. Defaults to 1.0.
|
||||||
|
cutmix_minmax (List[float], optional): The min/max area ratio of the
|
||||||
|
patches. If not None, the bounding-box of patches is uniform
|
||||||
|
sampled within this ratio range, and the ``alpha`` will be ignored.
|
||||||
|
Otherwise, the bounding-box is generated according to the
|
||||||
|
``alpha``. Defaults to None.
|
||||||
|
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
|
||||||
|
clipped by image borders. Defaults to True.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If the ``cutmix_minmax`` is None, how to generate the bounding-box of
|
||||||
|
patches according to the ``alpha``?
|
||||||
|
|
||||||
|
First, generate a :math:`\lambda`, details can be found in
|
||||||
|
:class:`BatchMixupLayer`. And then, the area ratio of the bounding-box
|
||||||
|
is calculated by:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{ratio} = \sqrt{1-\lambda}
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(BatchCutMixLayer, self).__init__(*args, **kwargs)
|
super(BatchCutMixLayer, self).__init__(*args, **kwargs)
|
||||||
|
@ -12,7 +12,8 @@ class BaseMixupLayer(object, metaclass=ABCMeta):
|
|||||||
"""Base class for MixupLayer.
|
"""Base class for MixupLayer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
alpha (float): Parameters for Beta distribution.
|
alpha (float): Parameters for Beta distribution to generate the
|
||||||
|
mixing ratio. It should be a positive number.
|
||||||
num_classes (int): The number of classes.
|
num_classes (int): The number of classes.
|
||||||
prob (float): MixUp probability. It should be in range [0, 1].
|
prob (float): MixUp probability. It should be in range [0, 1].
|
||||||
Default to 1.0
|
Default to 1.0
|
||||||
@ -36,7 +37,29 @@ class BaseMixupLayer(object, metaclass=ABCMeta):
|
|||||||
|
|
||||||
@AUGMENT.register_module(name='BatchMixup')
|
@AUGMENT.register_module(name='BatchMixup')
|
||||||
class BatchMixupLayer(BaseMixupLayer):
|
class BatchMixupLayer(BaseMixupLayer):
|
||||||
"""Mixup layer for batch mixup."""
|
r"""Mixup layer for a batch of data.
|
||||||
|
|
||||||
|
Mixup is a method to reduce the memorization of corrupt labels and
|
||||||
|
increases the robustness to adversarial examples. It's
|
||||||
|
proposed in `mixup: Beyond Empirical Risk Minimization
|
||||||
|
<https://arxiv.org/abs/1710.09412>`_
|
||||||
|
|
||||||
|
This method simply linearly mixes pairs of data and their labels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alpha (float): Parameters for Beta distribution to generate the
|
||||||
|
mixing ratio. It should be a positive number. More details
|
||||||
|
are in the note.
|
||||||
|
num_classes (int): The number of classes.
|
||||||
|
prob (float): The probability to execute mixup. It should be in
|
||||||
|
range [0, 1]. Defaults to 1.0.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The :math:`\alpha` (``alpha``) determines a random distribution
|
||||||
|
:math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
|
||||||
|
a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
|
||||||
|
distribution.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(BatchMixupLayer, self).__init__(*args, **kwargs)
|
super(BatchMixupLayer, self).__init__(*args, **kwargs)
|
||||||
|
@ -10,27 +10,31 @@ from .utils import one_hot_encoding
|
|||||||
|
|
||||||
@AUGMENT.register_module(name='BatchResizeMix')
|
@AUGMENT.register_module(name='BatchResizeMix')
|
||||||
class BatchResizeMixLayer(BatchCutMixLayer):
|
class BatchResizeMixLayer(BatchCutMixLayer):
|
||||||
r"""ResizeMix Random Paste layer for batch ResizeMix.
|
r"""ResizeMix Random Paste layer for a batch of data.
|
||||||
|
|
||||||
The ResizeMix will resize an image to a small patch and paste it on another
|
The ResizeMix will resize an image to a small patch and paste it on another
|
||||||
image. More details can be found in `ResizeMix: Mixing Data with Preserved
|
image. It's proposed in `ResizeMix: Mixing Data with Preserved Object
|
||||||
Object Information and True Labels <https://arxiv.org/abs/2012.11101>`_
|
Information and True Labels <https://arxiv.org/abs/2012.11101>`_
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
alpha (float): Parameters for Beta distribution. Positive(>0)
|
alpha (float): Parameters for Beta distribution to generate the
|
||||||
|
mixing ratio. It should be a positive number. More details
|
||||||
|
can be found in :class:`BatchMixupLayer`.
|
||||||
num_classes (int): The number of classes.
|
num_classes (int): The number of classes.
|
||||||
lam_min(float): The minimum value of lam. Defaults to 0.1.
|
lam_min(float): The minimum value of lam. Defaults to 0.1.
|
||||||
lam_max(float): The maximum value of lam. Defaults to 0.8.
|
lam_max(float): The maximum value of lam. Defaults to 0.8.
|
||||||
interpolation (str): algorithm used for upsampling:
|
interpolation (str): algorithm used for upsampling:
|
||||||
'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'.
|
'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' |
|
||||||
Default to 'bilinear'.
|
'area'. Default to 'bilinear'.
|
||||||
prob (float): mix probability. It should be in range [0, 1].
|
prob (float): The probability to execute resizemix. It should be in
|
||||||
Default to 1.0.
|
range [0, 1]. Defaults to 1.0.
|
||||||
cutmix_minmax (List[float], optional): cutmix min/max image ratio.
|
cutmix_minmax (List[float], optional): The min/max area ratio of the
|
||||||
(as percent of image size). When cutmix_minmax is not None, we
|
patches. If not None, the bounding-box of patches is uniform
|
||||||
generate cutmix bounding-box using cutmix_minmax instead of alpha
|
sampled within this ratio range, and the ``alpha`` will be ignored.
|
||||||
|
Otherwise, the bounding-box is generated according to the
|
||||||
|
``alpha``. Defaults to None.
|
||||||
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
|
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
|
||||||
clipped by image borders. Default to True
|
clipped by image borders. Defaults to True
|
||||||
**kwargs: Any other parameters accepted by :class:`BatchCutMixLayer`.
|
**kwargs: Any other parameters accepted by :class:`BatchCutMixLayer`.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
@ -45,7 +49,7 @@ class BatchResizeMixLayer(BatchCutMixLayer):
|
|||||||
And the resize ratio of source images is calculated by :math:`\lambda`:
|
And the resize ratio of source images is calculated by :math:`\lambda`:
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\text{ratio} = \sqrt{1-lam}
|
\text{ratio} = \sqrt{1-\lambda}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -8,6 +8,8 @@ from mmcv.utils import digit_version
|
|||||||
|
|
||||||
|
|
||||||
def is_tracing() -> bool:
|
def is_tracing() -> bool:
|
||||||
|
"""Determine whether the model is called during the tracing of code with
|
||||||
|
``torch.jit.trace``."""
|
||||||
if digit_version(torch.__version__) >= digit_version('1.6.0'):
|
if digit_version(torch.__version__) >= digit_version('1.6.0'):
|
||||||
on_trace = torch.jit.is_tracing()
|
on_trace = torch.jit.is_tracing()
|
||||||
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
|
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
|
||||||
@ -26,6 +28,15 @@ def is_tracing() -> bool:
|
|||||||
|
|
||||||
# From PyTorch internals
|
# From PyTorch internals
|
||||||
def _ntuple(n):
|
def _ntuple(n):
|
||||||
|
"""A `to_tuple` function generator.
|
||||||
|
|
||||||
|
It returns a function, this function will repeat the input to a tuple of
|
||||||
|
length ``n`` if the input is not an Iterable object, otherwise, return the
|
||||||
|
input directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The number of the target length.
|
||||||
|
"""
|
||||||
|
|
||||||
def parse(x):
|
def parse(x):
|
||||||
if isinstance(x, collections.abc.Iterable):
|
if isinstance(x, collections.abc.Iterable):
|
||||||
|
@ -7,6 +7,16 @@ from mmcv.utils import get_logger
|
|||||||
|
|
||||||
|
|
||||||
def get_root_logger(log_file=None, log_level=logging.INFO):
|
def get_root_logger(log_file=None, log_level=logging.INFO):
|
||||||
|
"""Get root logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_file (str, optional): File path of log. Defaults to None.
|
||||||
|
log_level (int, optional): The level of logger.
|
||||||
|
Defaults to :obj:`logging.INFO`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`logging.Logger`: The obtained logger
|
||||||
|
"""
|
||||||
return get_logger('mmcls', log_file, log_level)
|
return get_logger('mmcls', log_file, log_level)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user