update docstring (#20)

pull/255/head
humu789 2021-12-23 12:02:39 +08:00 committed by GitHub
parent ea36480b14
commit b53b3950ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 28 additions and 28 deletions

View File

@ -1,5 +1,5 @@
_base_ = [
'../../_base_/datasets/mmcls/cifar10_bs16.py',
'../../_base_/datasets/mmcls/cifar10_bs16.py',
'../../_base_/mmcls_runtime.py'
]

View File

@ -4,24 +4,24 @@
### CWD
Please refer to [CWD](/configs/distill/cwd/README.md) for details.
Please refer to [CWD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/cwd/README.md) for details.
### WSLD
Please refer to [WSLD](/configs/distill/wsld/README.md) for details.
Please refer to [WSLD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/wsld/README.md) for details.
### DARTS
Please refer to [DARTS](/configs/nas/darts/README.md) for details.
Please refer to [DARTS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/darts/README.md) for details.
### DETNAS
Please refer to [DETNAS](/configs/nas/detnas/README.md) for details.
Please refer to [DETNAS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/detnas/README.md) for details.
### SPOS
Please refer to [SPOS](/configs/nas/spos/README.md) for details.
Please refer to [SPOS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/README.md) for details.
### AUTOSLIM
Please refer to [AUTOSLIM](/configs/pruning/autoslim/README.md) for details.
Please refer to [AUTOSLIM](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md) for details.

View File

@ -8,7 +8,7 @@ To test nas method, you can use following command
python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments]
```
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for testing. An example for `mutable_cfg` can be found [here](/configs/nas/spos/SPOS_SHUFFLENET_300M.yaml).
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for testing. An example for `mutable_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml).
The usage of optional arguments are the same as corresponding tasks like mmclassification, mmdetection and mmsegmentation.
@ -35,7 +35,7 @@ To test pruning method, you can use following command
python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.channel_cfg=${CHANNEL_CFG_PATH} [optional arguments]
```
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](/configs/pruning/autoslim/README.md#test-a-subnet).
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md#test-a-subnet).
## Distillation

View File

@ -31,7 +31,7 @@ python tools/${task}/search_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} [option
python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments]
```
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for retraining. An example for `mutable_cfg` can be found [here](/configs/nas/spos/SPOS_SHUFFLENET_300M.yaml), and the usage can be found [here](/configs/nas/spos/README.md#subnet-retraining-on-imagenet).
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for retraining. An example for `mutable_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/README.md#subnet-retraining-on-imagenet).
## Pruning
@ -45,7 +45,7 @@ python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.cha
Different from NAS, the argument that needs to be specified here is `channel_cfg` instead of `mutable_cfg`.
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](/configs/pruning/autoslim/README.md#subnet-retraining-on-imagenet).
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md#subnet-retraining-on-imagenet).
## Distillation

View File

@ -10,6 +10,7 @@ class AlignMethodDistill(GeneralDistill):
super(AlignMethodDistill, self).__init__(**kwargs)
def train_step(self, data, optimizer):
with self.distiller.context_manager:
outputs = super().train_step(data, optimizer)
return outputs

View File

@ -25,7 +25,7 @@ class GeneralDistill(BaseAlgorithm):
self.with_teacher_loss = with_teacher_loss
def train_step(self, data, optimizer):
""""""
losses = dict()
if self.with_teacher_loss:
teacher_losses = self.distiller.exec_teacher_forward(data)

View File

@ -14,7 +14,7 @@ class ChannelWiseDivergence(nn.Module):
Args:
tau (float): Temperature coefficient. Defaults to 1.0.
weight (float): Weight of loss. Defaults to 1.0.
loss_weight (float): Weight of loss. Defaults to 1.0.
"""
def __init__(

View File

@ -8,17 +8,17 @@ from ..builder import LOSSES
@LOSSES.register_module()
class WSLD(nn.Module):
"""PyTorch version of `Rethinking Soft Labels for Knowledge
Distillation: A Bias-Variance Tradeoff Perspective
<https://arxiv.org/abs/2102.00650>`_.
Args:
tau (float): Temperature coefficient. Defaults to 1.0.
weight (float): Weight of loss. Defaults to 1.0.
num_classes (int): Defaults to 1000.
"""
def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
"""PyTorch version of `Rethinking Soft Labels for Knowledge
Distillation: A Bias-Variance Tradeoff Perspective
<https://arxiv.org/abs/2102.00650>`_.
Args:
tau (float): Temperature coefficient. Defaults to 1.0.
weight (float): Weight of loss. Defaults to 1.0.
num_classes (int): Defaults to 1000.
"""
super(WSLD, self).__init__()
self.tau = tau

View File

@ -69,12 +69,11 @@ def register_parser(parser_dict, name=None, force=False):
@PRUNERS.register_module()
class StructurePruner(BaseModule, metaclass=ABCMeta):
"""Base class for structure pruning.
This class defines the basic functions of a structure pruner. Any pruner
that inherits this class should at least define its own `sample_subnet` and
`set_min_channel` functions. This part is being continuously optimized, and
there may be major changes in the future.
"""Base class for structure pruning. This class defines the basic functions
of a structure pruner. Any pruner that inherits this class should at least
define its own `sample_subnet` and `set_min_channel` functions. This part
is being continuously optimized, and there may be major changes in the
future.
Args:
except_start_keys (List[str]): the module whose name start with a