update docstring (#20)
parent
ea36480b14
commit
b53b3950ae
|
@ -4,24 +4,24 @@
|
||||||
|
|
||||||
### CWD
|
### CWD
|
||||||
|
|
||||||
Please refer to [CWD](/configs/distill/cwd/README.md) for details.
|
Please refer to [CWD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/cwd/README.md) for details.
|
||||||
|
|
||||||
### WSLD
|
### WSLD
|
||||||
|
|
||||||
Please refer to [WSLD](/configs/distill/wsld/README.md) for details.
|
Please refer to [WSLD](https://github.com/open-mmlab/mmrazor/blob/master/configs/distill/wsld/README.md) for details.
|
||||||
|
|
||||||
### DARTS
|
### DARTS
|
||||||
|
|
||||||
Please refer to [DARTS](/configs/nas/darts/README.md) for details.
|
Please refer to [DARTS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/darts/README.md) for details.
|
||||||
|
|
||||||
### DETNAS
|
### DETNAS
|
||||||
|
|
||||||
Please refer to [DETNAS](/configs/nas/detnas/README.md) for details.
|
Please refer to [DETNAS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/detnas/README.md) for details.
|
||||||
|
|
||||||
### SPOS
|
### SPOS
|
||||||
|
|
||||||
Please refer to [SPOS](/configs/nas/spos/README.md) for details.
|
Please refer to [SPOS](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/README.md) for details.
|
||||||
|
|
||||||
### AUTOSLIM
|
### AUTOSLIM
|
||||||
|
|
||||||
Please refer to [AUTOSLIM](/configs/pruning/autoslim/README.md) for details.
|
Please refer to [AUTOSLIM](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md) for details.
|
||||||
|
|
|
@ -8,7 +8,7 @@ To test nas method, you can use following command
|
||||||
python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments]
|
python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments]
|
||||||
```
|
```
|
||||||
|
|
||||||
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for testing. An example for `mutable_cfg` can be found [here](/configs/nas/spos/SPOS_SHUFFLENET_300M.yaml).
|
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for testing. An example for `mutable_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml).
|
||||||
|
|
||||||
The usage of optional arguments are the same as corresponding tasks like mmclassification, mmdetection and mmsegmentation.
|
The usage of optional arguments are the same as corresponding tasks like mmclassification, mmdetection and mmsegmentation.
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ To test pruning method, you can use following command
|
||||||
python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.channel_cfg=${CHANNEL_CFG_PATH} [optional arguments]
|
python tools/${task}/test_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} --cfg-options algorithm.channel_cfg=${CHANNEL_CFG_PATH} [optional arguments]
|
||||||
```
|
```
|
||||||
|
|
||||||
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](/configs/pruning/autoslim/README.md#test-a-subnet).
|
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md#test-a-subnet).
|
||||||
|
|
||||||
## Distillation
|
## Distillation
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ python tools/${task}/search_${task}.py ${CONFIG_FILE} ${CHECKPOINT_PATH} [option
|
||||||
python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments]
|
python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.mutable_cfg=${MUTABLE_CFG_PATH} [optional arguments]
|
||||||
```
|
```
|
||||||
|
|
||||||
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for retraining. An example for `mutable_cfg` can be found [here](/configs/nas/spos/SPOS_SHUFFLENET_300M.yaml), and the usage can be found [here](/configs/nas/spos/README.md#subnet-retraining-on-imagenet).
|
- `MUTABLE_CFG_PATH`: Path of `mutable_cfg`. `mutable_cfg` represents **config for mutable of the subnet searched out**, used to specify different subnets for retraining. An example for `mutable_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/SPOS_SHUFFLENETV2_330M_IN1k_PAPER.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/nas/spos/README.md#subnet-retraining-on-imagenet).
|
||||||
|
|
||||||
## Pruning
|
## Pruning
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ python tools/${task}/train_${task}.py ${CONFIG_FILE} --cfg-options algorithm.cha
|
||||||
|
|
||||||
Different from NAS, the argument that needs to be specified here is `channel_cfg` instead of `mutable_cfg`.
|
Different from NAS, the argument that needs to be specified here is `channel_cfg` instead of `mutable_cfg`.
|
||||||
|
|
||||||
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](/configs/pruning/autoslim/README.md#subnet-retraining-on-imagenet).
|
- `CHANNEL_CFG_PATH`: Path of `channel_cfg`. `channel_cfg` represents **config for channel of the subnet searched out**, used to specify different subnets for testing. An example for `channel_cfg` can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/AUTOSLIM_MBV2_220M_OFFICIAL.yaml), and the usage can be found [here](https://github.com/open-mmlab/mmrazor/blob/master/configs/pruning/autoslim/README.md#subnet-retraining-on-imagenet).
|
||||||
|
|
||||||
## Distillation
|
## Distillation
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ class AlignMethodDistill(GeneralDistill):
|
||||||
super(AlignMethodDistill, self).__init__(**kwargs)
|
super(AlignMethodDistill, self).__init__(**kwargs)
|
||||||
|
|
||||||
def train_step(self, data, optimizer):
|
def train_step(self, data, optimizer):
|
||||||
|
|
||||||
with self.distiller.context_manager:
|
with self.distiller.context_manager:
|
||||||
outputs = super().train_step(data, optimizer)
|
outputs = super().train_step(data, optimizer)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
|
@ -25,7 +25,7 @@ class GeneralDistill(BaseAlgorithm):
|
||||||
self.with_teacher_loss = with_teacher_loss
|
self.with_teacher_loss = with_teacher_loss
|
||||||
|
|
||||||
def train_step(self, data, optimizer):
|
def train_step(self, data, optimizer):
|
||||||
|
""""""
|
||||||
losses = dict()
|
losses = dict()
|
||||||
if self.with_teacher_loss:
|
if self.with_teacher_loss:
|
||||||
teacher_losses = self.distiller.exec_teacher_forward(data)
|
teacher_losses = self.distiller.exec_teacher_forward(data)
|
||||||
|
|
|
@ -14,7 +14,7 @@ class ChannelWiseDivergence(nn.Module):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tau (float): Temperature coefficient. Defaults to 1.0.
|
tau (float): Temperature coefficient. Defaults to 1.0.
|
||||||
weight (float): Weight of loss. Defaults to 1.0.
|
loss_weight (float): Weight of loss. Defaults to 1.0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -8,8 +8,6 @@ from ..builder import LOSSES
|
||||||
|
|
||||||
@LOSSES.register_module()
|
@LOSSES.register_module()
|
||||||
class WSLD(nn.Module):
|
class WSLD(nn.Module):
|
||||||
|
|
||||||
def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
|
|
||||||
"""PyTorch version of `Rethinking Soft Labels for Knowledge
|
"""PyTorch version of `Rethinking Soft Labels for Knowledge
|
||||||
Distillation: A Bias-Variance Tradeoff Perspective
|
Distillation: A Bias-Variance Tradeoff Perspective
|
||||||
<https://arxiv.org/abs/2102.00650>`_.
|
<https://arxiv.org/abs/2102.00650>`_.
|
||||||
|
@ -19,6 +17,8 @@ class WSLD(nn.Module):
|
||||||
weight (float): Weight of loss. Defaults to 1.0.
|
weight (float): Weight of loss. Defaults to 1.0.
|
||||||
num_classes (int): Defaults to 1000.
|
num_classes (int): Defaults to 1000.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
|
||||||
super(WSLD, self).__init__()
|
super(WSLD, self).__init__()
|
||||||
|
|
||||||
self.tau = tau
|
self.tau = tau
|
||||||
|
|
|
@ -69,12 +69,11 @@ def register_parser(parser_dict, name=None, force=False):
|
||||||
|
|
||||||
@PRUNERS.register_module()
|
@PRUNERS.register_module()
|
||||||
class StructurePruner(BaseModule, metaclass=ABCMeta):
|
class StructurePruner(BaseModule, metaclass=ABCMeta):
|
||||||
"""Base class for structure pruning.
|
"""Base class for structure pruning. This class defines the basic functions
|
||||||
|
of a structure pruner. Any pruner that inherits this class should at least
|
||||||
This class defines the basic functions of a structure pruner. Any pruner
|
define its own `sample_subnet` and `set_min_channel` functions. This part
|
||||||
that inherits this class should at least define its own `sample_subnet` and
|
is being continuously optimized, and there may be major changes in the
|
||||||
`set_min_channel` functions. This part is being continuously optimized, and
|
future.
|
||||||
there may be major changes in the future.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
except_start_keys (List[str]): the module whose name start with a
|
except_start_keys (List[str]): the module whose name start with a
|
||||||
|
|
Loading…
Reference in New Issue