update docstring (#20)
parent
ea36480b14
commit
b53b3950ae
|
@ -1,5 +1,5 @@
|
|||
_base_ = [
|
||||
'../../_base_/datasets/mmcls/cifar10_bs16.py',
|
||||
'../../_base_/datasets/mmcls/cifar10_bs16.py',
|
||||
'../../_base_/mmcls_runtime.py'
|
||||
]
|
||||
|
||||
|
|
|
@ -4,24 +4,24 @@
|
|||
|
||||
### CWD
|
||||
|
||||
Please refer to [CWD](/configs/distill/cwd/README.md) for details.
|
||||
Please refer to [CWD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/cwd/README.md) for details.
|
||||
|
||||
### WSLD
|
||||
|
||||
Please refer to [WSLD](/configs/distill/wsld/README.md) for details.
|
||||
Please refer to [WSLD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/wsld/README.md) for details.
|
||||
|
||||
### DARTS
|
||||
|
||||
Please refer to [DARTS](/configs/nas/darts/README.md) for details.
|
||||
Please refer to [DARTS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/darts/README.md) for details.
|
||||
|
||||
### DETNAS
|
||||
|
||||
Please refer to [DETNAS](/configs/nas/detnas/README.md) for details.
|
||||
Please refer to [DETNAS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/detnas/README.md) for details.
|
||||
|
||||
### SPOS
|
||||
|
||||
Please refer to [SPOS](/configs/nas/spos/README.md) for details.
|
||||
Please refer to [SPOS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/README.md) for details.
|
||||
|
||||
### AUTOSLIM
|
||||
|
||||
Please refer to [AUTOSLIM](/configs/pruning/autoslim/README.md) for details.
|
||||
Please refer to [AUTOSLIM](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md) for details.
|
||||
|
|
|
@ -8,7 +8,7 @@ To test nas method, you can use following command
|
|||
python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments]
|
||||
```
|
||||
|
||||
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for testing. An example for `mutable_cfg` can be found [here](/configs/nas/spos/SPOS_SHUFFLENET_300M.yaml).
|
||||
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for testing. An example for `mutable_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml).
|
||||
|
||||
The usage of optional arguments are the same as corresponding tasks like mmclassification, mmdetection and mmsegmentation.
|
||||
|
||||
|
@ -35,7 +35,7 @@ To test pruning method, you can use following command
|
|||
python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.channel_cfg=${CHANNEL_CFG_PATH} [optional arguments]
|
||||
```
|
||||
|
||||
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](/configs/pruning/autoslim/README.md#test-a-subnet).
|
||||
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md#test-a-subnet).
|
||||
|
||||
## Distillation
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ python tools/${task}/search_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} [option
|
|||
python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments]
|
||||
```
|
||||
|
||||
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for retraining. An example for `mutable_cfg` can be found [here](/configs/nas/spos/SPOS_SHUFFLENET_300M.yaml), and the usage can be found [here](/configs/nas/spos/README.md#subnet-retraining-on-imagenet).
|
||||
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for retraining. An example for `mutable_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/README.md#subnet-retraining-on-imagenet).
|
||||
|
||||
## Pruning
|
||||
|
||||
|
@ -45,7 +45,7 @@ python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.cha
|
|||
|
||||
Different from NAS, the argument that needs to be specified here is `channel_cfg` instead of `mutable_cfg`.
|
||||
|
||||
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](/configs/pruning/autoslim/README.md#subnet-retraining-on-imagenet).
|
||||
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md#subnet-retraining-on-imagenet).
|
||||
|
||||
## Distillation
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ class AlignMethodDistill(GeneralDistill):
|
|||
super(AlignMethodDistill, self).__init__(**kwargs)
|
||||
|
||||
def train_step(self, data, optimizer):
|
||||
|
||||
with self.distiller.context_manager:
|
||||
outputs = super().train_step(data, optimizer)
|
||||
return outputs
|
||||
|
|
|
@ -25,7 +25,7 @@ class GeneralDistill(BaseAlgorithm):
|
|||
self.with_teacher_loss = with_teacher_loss
|
||||
|
||||
def train_step(self, data, optimizer):
|
||||
|
||||
""""""
|
||||
losses = dict()
|
||||
if self.with_teacher_loss:
|
||||
teacher_losses = self.distiller.exec_teacher_forward(data)
|
||||
|
|
|
@ -14,7 +14,7 @@ class ChannelWiseDivergence(nn.Module):
|
|||
|
||||
Args:
|
||||
tau (float): Temperature coefficient. Defaults to 1.0.
|
||||
weight (float): Weight of loss. Defaults to 1.0.
|
||||
loss_weight (float): Weight of loss. Defaults to 1.0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -8,17 +8,17 @@ from ..builder import LOSSES
|
|||
|
||||
@LOSSES.register_module()
|
||||
class WSLD(nn.Module):
|
||||
"""PyTorch version of `Rethinking Soft Labels for Knowledge
|
||||
Distillation: A Bias-Variance Tradeoff Perspective
|
||||
<https://arxiv.org/abs/2102.00650>`_.
|
||||
|
||||
Args:
|
||||
tau (float): Temperature coefficient. Defaults to 1.0.
|
||||
weight (float): Weight of loss. Defaults to 1.0.
|
||||
num_classes (int): Defaults to 1000.
|
||||
"""
|
||||
|
||||
def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
|
||||
"""PyTorch version of `Rethinking Soft Labels for Knowledge
|
||||
Distillation: A Bias-Variance Tradeoff Perspective
|
||||
<https://arxiv.org/abs/2102.00650>`_.
|
||||
|
||||
Args:
|
||||
tau (float): Temperature coefficient. Defaults to 1.0.
|
||||
weight (float): Weight of loss. Defaults to 1.0.
|
||||
num_classes (int): Defaults to 1000.
|
||||
"""
|
||||
super(WSLD, self).__init__()
|
||||
|
||||
self.tau = tau
|
||||
|
|
|
@ -69,12 +69,11 @@ def register_parser(parser_dict, name=None, force=False):
|
|||
|
||||
@PRUNERS.register_module()
|
||||
class StructurePruner(BaseModule, metaclass=ABCMeta):
|
||||
"""Base class for structure pruning.
|
||||
|
||||
This class defines the basic functions of a structure pruner. Any pruner
|
||||
that inherits this class should at least define its own `sample_subnet` and
|
||||
`set_min_channel` functions. This part is being continuously optimized, and
|
||||
there may be major changes in the future.
|
||||
"""Base class for structure pruning. This class defines the basic functions
|
||||
of a structure pruner. Any pruner that inherits this class should at least
|
||||
define its own `sample_subnet` and `set_min_channel` functions. This part
|
||||
is being continuously optimized, and there may be major changes in the
|
||||
future.
|
||||
|
||||
Args:
|
||||
except_start_keys (List[str]): the module whose name start with a
|
||||
|
|
Loading…
Reference in New Issue