update docstring (#20)

pull/255/head
humu789 2021-12-23 12:02:39 +08:00 committed by GitHub
parent ea36480b14
commit b53b3950ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 28 additions and 28 deletions

View File

@ -4,24 +4,24 @@
### CWD ### CWD
Please refer to [CWD](/configs/distill/cwd/README.md) for details. Please refer to [CWD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/cwd/README.md) for details.
### WSLD ### WSLD
Please refer to [WSLD](/configs/distill/wsld/README.md) for details. Please refer to [WSLD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/wsld/README.md) for details.
### DARTS ### DARTS
Please refer to [DARTS](/configs/nas/darts/README.md) for details. Please refer to [DARTS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/darts/README.md) for details.
### DETNAS ### DETNAS
Please refer to [DETNAS](/configs/nas/detnas/README.md) for details. Please refer to [DETNAS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/detnas/README.md) for details.
### SPOS ### SPOS
Please refer to [SPOS](/configs/nas/spos/README.md) for details. Please refer to [SPOS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/README.md) for details.
### AUTOSLIM ### AUTOSLIM
Please refer to [AUTOSLIM](/configs/pruning/autoslim/README.md) for details. Please refer to [AUTOSLIM](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md) for details.

View File

@ -8,7 +8,7 @@ To test nas method, you can use following command
python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments] python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments]
``` ```
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for testing. An example for `mutable_cfg` can be found [here](/configs/nas/spos/SPOS_SHUFFLENET_300M.yaml). - `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for testing. An example for `mutable_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml).
The usage of optional arguments are the same as corresponding tasks like mmclassification, mmdetection and mmsegmentation. The usage of optional arguments are the same as corresponding tasks like mmclassification, mmdetection and mmsegmentation.
@ -35,7 +35,7 @@ To test pruning method, you can use following command
python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.channel_cfg=${CHANNEL_CFG_PATH} [optional arguments] python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.channel_cfg=${CHANNEL_CFG_PATH} [optional arguments]
``` ```
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](/configs/pruning/autoslim/README.md#test-a-subnet). - `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md#test-a-subnet).
## Distillation ## Distillation

View File

@ -31,7 +31,7 @@ python tools/${task}/search_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} [option
python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments] python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments]
``` ```
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for retraining. An example for `mutable_cfg` can be found [here](/configs/nas/spos/SPOS_SHUFFLENET_300M.yaml), and the usage can be found [here](/configs/nas/spos/README.md#subnet-retraining-on-imagenet). - `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for retraining. An example for `mutable_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/README.md#subnet-retraining-on-imagenet).
## Pruning ## Pruning
@ -45,7 +45,7 @@ python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.cha
Different from NAS, the argument that needs to be specified here is `channel_cfg` instead of `mutable_cfg`. Different from NAS, the argument that needs to be specified here is `channel_cfg` instead of `mutable_cfg`.
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](/configs/pruning/autoslim/README.md#subnet-retraining-on-imagenet). - `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md#subnet-retraining-on-imagenet).
## Distillation ## Distillation

View File

@ -10,6 +10,7 @@ class AlignMethodDistill(GeneralDistill):
super(AlignMethodDistill, self).__init__(**kwargs) super(AlignMethodDistill, self).__init__(**kwargs)
def train_step(self, data, optimizer): def train_step(self, data, optimizer):
with self.distiller.context_manager: with self.distiller.context_manager:
outputs = super().train_step(data, optimizer) outputs = super().train_step(data, optimizer)
return outputs return outputs

View File

@ -25,7 +25,7 @@ class GeneralDistill(BaseAlgorithm):
self.with_teacher_loss = with_teacher_loss self.with_teacher_loss = with_teacher_loss
def train_step(self, data, optimizer): def train_step(self, data, optimizer):
""""""
losses = dict() losses = dict()
if self.with_teacher_loss: if self.with_teacher_loss:
teacher_losses = self.distiller.exec_teacher_forward(data) teacher_losses = self.distiller.exec_teacher_forward(data)

View File

@ -14,7 +14,7 @@ class ChannelWiseDivergence(nn.Module):
Args: Args:
tau (float): Temperature coefficient. Defaults to 1.0. tau (float): Temperature coefficient. Defaults to 1.0.
weight (float): Weight of loss. Defaults to 1.0. loss_weight (float): Weight of loss. Defaults to 1.0.
""" """
def __init__( def __init__(

View File

@ -8,8 +8,6 @@ from ..builder import LOSSES
@LOSSES.register_module() @LOSSES.register_module()
class WSLD(nn.Module): class WSLD(nn.Module):
def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
"""PyTorch version of `Rethinking Soft Labels for Knowledge """PyTorch version of `Rethinking Soft Labels for Knowledge
Distillation: A Bias-Variance Tradeoff Perspective Distillation: A Bias-Variance Tradeoff Perspective
<https://arxiv.org/abs/2102.00650>`_. <https://arxiv.org/abs/2102.00650>`_.
@ -19,6 +17,8 @@ class WSLD(nn.Module):
weight (float): Weight of loss. Defaults to 1.0. weight (float): Weight of loss. Defaults to 1.0.
num_classes (int): Defaults to 1000. num_classes (int): Defaults to 1000.
""" """
def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
super(WSLD, self).__init__() super(WSLD, self).__init__()
self.tau = tau self.tau = tau

View File

@ -69,12 +69,11 @@ def register_parser(parser_dict, name=None, force=False):
@PRUNERS.register_module() @PRUNERS.register_module()
class StructurePruner(BaseModule, metaclass=ABCMeta): class StructurePruner(BaseModule, metaclass=ABCMeta):
"""Base class for structure pruning. """Base class for structure pruning. This class defines the basic functions
of a structure pruner. Any pruner that inherits this class should at least
This class defines the basic functions of a structure pruner. Any pruner define its own `sample_subnet` and `set_min_channel` functions. This part
that inherits this class should at least define its own `sample_subnet` and is being continuously optimized, and there may be major changes in the
`set_min_channel` functions. This part is being continuously optimized, and future.
there may be major changes in the future.
Args: Args:
except_start_keys (List[str]): the module whose name start with a except_start_keys (List[str]): the module whose name start with a