Knowledge Distillation

Here we show how to develop a new KD algorithm, using SingleTeacherDistill as an example.

  1. Register a new algorithm

Create a new file mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py, in which class SingleTeacherDistill inherits from class BaseAlgorithm.

from mmrazor.registry import MODELS
from ..base import BaseAlgorithm

@MODELS.register_module()
class SingleTeacherDistill(BaseAlgorithm):
    def __init__(self, use_gt, **kwargs):
        super().__init__(**kwargs)
        pass

    def train_step(self, data, optimizer):
        pass
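
Before filling in the skeleton, it may help to see what a single-teacher distillation step computes. The following is a generic Hinton-style sketch in plain PyTorch, not MMRazor's implementation; the function name, alpha, and temperature T are illustrative assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

def distill_step(student: nn.Module, teacher: nn.Module,
                 images: torch.Tensor, labels: torch.Tensor,
                 alpha: float = 0.5, T: float = 4.0) -> torch.Tensor:
    """Generic KD step: task loss plus temperature-softened KL to the teacher."""
    with torch.no_grad():  # the teacher is frozen during distillation
        t_logits = teacher(images)
    s_logits = student(images)
    task_loss = F.cross_entropy(s_logits, labels)
    kd_loss = F.kl_div(
        F.log_softmax(s_logits / T, dim=1),
        F.softmax(t_logits / T, dim=1),
        reduction='batchmean') * (T * T)
    return (1 - alpha) * task_loss + alpha * kd_loss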
  2. Develop connectors (Optional).

Take ConvModuleConnector as an example.

from mmrazor.registry import MODELS
from .base_connector import BaseConnector

@MODELS.register_module()
class ConvModuleConnector(BaseConnector):
    def __init__(self, in_channel, out_channel, kernel_size=1, stride=1):
        ...

    def forward_train(self, feature):
        ...
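
A connector adapts the student's intermediate features to the teacher's (for example, matching channel counts) before a distillation loss is computed. Below is a minimal self-contained sketch in plain PyTorch; the class and parameter names are hypothetical and only illustrate the idea:

import torch
import torch.nn as nn

class SimpleConvConnector(nn.Module):
    """Project student features to the teacher's channel count (illustrative)."""

    def __init__(self, in_channel: int, out_channel: int):
        super().__init__()
        self.proj = nn.Conv2d(in_channel, out_channel, kernel_size=1)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        return self.proj(feature)

# Align a 256-channel student feature with a 512-channel teacher feature.
connector = SimpleConvConnector(256, 512)
aligned = connector(torch.randn(2, 256, 32, 32))  # shape: (2, 512, 32, 32)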
  3. Develop a distiller.

Take ConfigurableDistiller as an example.

from .base_distiller import BaseDistiller
from mmrazor.registry import MODELS


@MODELS.register_module()
class ConfigurableDistiller(BaseDistiller):
    def __init__(self,
                 student_recorders=None,
                 teacher_recorders=None,
                 distill_deliveries=None,
                 connectors=None,
                 distill_losses=None,
                 loss_forward_mappings=None):
        ...

    def build_connectors(self, connectors):
        ...

    def build_distill_losses(self, losses):
        ...

    def compute_distill_losses(self):
        ...
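
The heart of ConfigurableDistiller is loss_forward_mappings, which wires recorded student/teacher features into the keyword arguments of each distill loss. A hedged configuration example follows; the recorder type and module path are assumptions, and the mapping keys must match the loss's forward signature (here, the s_feature/t_feature of the L1Loss developed in the next step):

distiller = dict(
    type='ConfigurableDistiller',
    # Record one intermediate feature from each model
    # ('ModuleOutputs' and the module path are illustrative assumptions).
    student_recorders=dict(feat=dict(type='ModuleOutputs', source='backbone.layer4')),
    teacher_recorders=dict(feat=dict(type='ModuleOutputs', source='backbone.layer4')),
    distill_losses=dict(loss_l1=dict(type='L1Loss', loss_weight=1.0)),
    # Feed the recorded features into L1Loss.forward(s_feature, t_feature).
    loss_forward_mappings=dict(
        loss_l1=dict(
            s_feature=dict(from_student=True, recorder='feat'),
            t_feature=dict(from_student=False, recorder='feat'))))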
  4. Develop a custom loss (Optional).

Here we take L1Loss as an example. Create a new file mmrazor/models/losses/l1_loss.py.

from typing import Optional

import torch.nn as nn
import torch.nn.functional as F

from mmrazor.registry import MODELS

@MODELS.register_module()
class L1Loss(nn.Module):
    def __init__(
        self,
        loss_weight: float = 1.0,
        size_average: Optional[bool] = None,
        reduce: Optional[bool] = None,
        reduction: str = 'mean',
    ) -> None:
        super().__init__()
        ...

    def forward(self, s_feature, t_feature):
        loss = F.l1_loss(s_feature, t_feature, self.size_average, self.reduce,
                         self.reduction)
        return self.loss_weight * loss
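
Assuming the elided __init__ body stores the constructor arguments as attributes (self.loss_weight, self.size_average, self.reduce, self.reduction), the loss can be sanity-checked directly:

import torch

loss_fn = L1Loss(loss_weight=0.5)
s_feature = torch.randn(2, 256, 8, 8)  # student feature (arbitrary shape)
t_feature = torch.randn(2, 256, 8, 8)  # teacher feature (same shape)
print(loss_fn(s_feature, t_feature))   # scalar: 0.5 * mean(|s - t|)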
  5. Import the class

You can either add the following lines to mmrazor/models/algorithms/__init__.py

from .single_teacher_distill import SingleTeacherDistill

__all__ = [..., 'SingleTeacherDistill']

or alternatively add

custom_imports = dict(
    imports=['mmrazor.models.algorithms.distill.configurable.single_teacher_distill'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

  6. Use the algorithm in your config file

algorithm = dict(
    type='SingleTeacherDistill',
    distiller=dict(type='ConfigurableDistiller', ...),
    # you can also use your new algorithm components here
    ...
)
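
Putting it together, a fuller (purely illustrative) sketch nests the distiller settings from step 3 inside the algorithm. The architecture/teacher field names follow common MMRazor conventions but ultimately depend on your algorithm's __init__ signature, and the elided model configs are placeholders:

algorithm = dict(
    type='SingleTeacherDistill',
    architecture=dict(...),  # student model config (placeholder)
    teacher=dict(...),       # teacher model config (placeholder)
    distiller=dict(
        type='ConfigurableDistiller',
        distill_losses=dict(loss_l1=dict(type='L1Loss', loss_weight=1.0)),
        ...))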