diff --git a/configs/_base_/datasets/mmcls/cifar100_bs16_auto_aug.py b/configs/_base_/datasets/mmcls/cifar100_bs16_auto_aug.py new file mode 100644 index 00000000..46c31ac7 --- /dev/null +++ b/configs/_base_/datasets/mmcls/cifar100_bs16_auto_aug.py @@ -0,0 +1,50 @@ +_base_ = ['./pipelines/auto_aug_cifar.py'] + +# dataset settings +dataset_type = 'CIFAR100' +preprocess_cfg = dict( + # RGB format normalization parameters + mean=[129.304, 124.070, 112.434], + std=[68.170, 65.392, 70.418], + # loaded images are already RGB format + to_rgb=False) + +train_pipeline = [ + dict(type='RandomCrop', crop_size=32, padding=4), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='Cutout', shape=16, pad_val=0), + dict(type='AutoAugment', policies={{_base_.policy_cifar}}), + dict(type='PackClsInputs'), +] + +test_pipeline = [ + dict(type='PackClsInputs'), +] + +train_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar100', + test_mode=False, + pipeline=train_pipeline), + sampler=dict(type='DefaultSampler', shuffle=True), + persistent_workers=True, +) + +val_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar100/', + test_mode=True, + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), + persistent_workers=True, +) +val_evaluator = dict(type='Accuracy', topk=(1, 5)) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/configs/_base_/datasets/mmcls/pipelines/auto_aug_cifar.py b/configs/_base_/datasets/mmcls/pipelines/auto_aug_cifar.py new file mode 100644 index 00000000..4767a8fe --- /dev/null +++ b/configs/_base_/datasets/mmcls/pipelines/auto_aug_cifar.py @@ -0,0 +1,125 @@ +# Policy for CIFAR, refer to +# https://github.com/DeepVoltaire/AutoAugment/blame/master/autoaugment.py +policy_cifar = [ + # Group 1 + [ + dict(type='Invert', prob=0.1), + dict(type='Contrast', magnitude=0.5, prob=0.2) + ], + [ + dict(type='Rotate', angle=10., prob=0.7), + dict(type='Translate', magnitude=150 / 331, prob=0.3) + ], + [ + dict(type='Sharpness', magnitude=0.9, prob=0.8), + dict(type='Sharpness', magnitude=0.3, prob=0.9) + ], + [ + dict( + type='Shear', + magnitude=0.3 / 9 * 8, + direction='vertical', + prob=0.5), + dict( + type='Translate', + magnitude=150 / 331, + direction='vertical', + prob=0.3) + ], + [dict(type='AutoContrast', prob=0.5), + dict(type='Equalize', prob=0.9)], + # Group 2 + [ + dict( + type='Shear', + magnitude=0.3 / 9 * 7, + direction='vertical', + prob=0.2), + dict(type='Posterize', bits=5, prob=0.3) + ], + [ + dict(type='ColorTransform', magnitude=0.3, prob=0.4), + dict(type='Brightness', magnitude=0.7, prob=0.7) + ], + [ + dict(type='Sharpness', magnitude=1.0, prob=0.3), + dict(type='Brightness', magnitude=1.0, prob=0.7) + ], + [dict(type='Equalize', prob=0.6), + dict(type='Equalize', prob=0.5)], + [ + dict(type='Contrast', magnitude=0.6, prob=0.6), + dict(type='Sharpness', magnitude=0.4, prob=0.8), + ], + # Group 3 + [ + dict(type='ColorTransform', magnitude=0.6, prob=0.7), + dict(type='Translate', magnitude=150 / 331 / 9 * 7, prob=0.5) + ], + [dict(type='Equalize', prob=0.3), + dict(type='AutoContrast', prob=0.4)], + [ + dict( + type='Translate', + magnitude=150 / 331 / 9 * 2, + direction='vertical', + prob=0.4), + dict(type='Sharpness', magnitude=0.5, prob=0.2) + ], + [ + dict(type='Brightness', magnitude=0.5, prob=0.9), + dict(type='ColorTransform', magnitude=0.7, prob=0.2), + ], + [ + 
dict(type='Solarize', thr=256 / 9 * 7, prob=0.5), + dict(type='Invert', prob=0.0), + ], + # Group 4 + [dict(type='Equalize', prob=0.2), + dict(type='AutoContrast', prob=0.6)], + [dict(type='Equalize', prob=0.2), + dict(type='Equalize', prob=0.6)], + [ + dict(type='ColorTransform', magnitude=0.9, prob=0.9), + dict(type='Equalize', prob=0.6) + ], + [ + dict(type='AutoContrast', prob=0.8), + dict(type='Solarize', thr=256 / 9 * 1, prob=0.2), + ], + [ + dict(type='Brightness', magnitude=0.3, prob=0.1), + dict(type='ColorTransform', magnitude=0.0, prob=0.7) + ], + # Group 5 + [ + dict(type='Solarize', thr=256 / 9 * 4, prob=0.4), + dict(type='AutoContrast', prob=0.9) + ], + [ + dict( + type='Translate', + magnitude=150 / 331, + direction='vertical', + prob=0.9), + dict( + type='Translate', + magnitude=150 / 331, + direction='vertical', + prob=0.7) + ], + [ + dict(type='AutoContrast', prob=0.9), + dict(type='Solarize', thr=256 / 9 * 6, prob=0.8) + ], + [dict(type='Equalize', prob=0.8), + dict(type='Invert', prob=0.1)], + [ + dict( + type='Translate', + magnitude=150 / 331, + direction='vertical', + prob=0.7), + dict(type='AutoContrast', prob=0.9) + ] +] diff --git a/configs/distill/mmcls/byot/README.md b/configs/distill/mmcls/byot/README.md new file mode 100644 index 00000000..b3e125f9 --- /dev/null +++ b/configs/distill/mmcls/byot/README.md @@ -0,0 +1,55 @@ +# Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation + +## Abstract + +Convolutional neural networks have been widely deployed in various application scenarios. In order to extend the applications' boundaries to some accuracy-crucial domains, researchers have been investigating approaches to boost accuracy through either deeper or wider network structures, which brings with them the exponential increment of the computational and storage cost, delaying the responding time. In this paper, we propose a general training framework named self distillation, which notably enhances the performance (accuracy) of convolutional neural networks through shrinking the size of the network rather than aggrandizing it. Different from traditional knowledge distillation - a knowledge transformation methodology among networks, which forces student neural networks to approximate the softmax layer outputs of pre-trained teacher neural networks, the proposed self distillation framework distills knowledge within network itself. The networks are firstly divided into several sections. Then the knowledge in the deeper portion of the networks is squeezed into the shallow ones. Experiments further prove the generalization of the proposed self distillation framework: enhancement of accuracy at average level is 2.65%, varying from 0.61% in ResNeXt as minimum to 4.07% in VGG19 as maximum. In addition, it can also provide flexibility of depth-wise scalable inference on resource-limited edge devices.Our codes will be released on github soon. 
[Unofficial code](https://github.com/luanyunteng/pytorch-be-your-own-teacher)
+
+## Pipeline
+
+![pipeline](../../../../docs/en/imgs/model_zoo/byot/byot.png)
+
+## Results and models
+
+### Classification
+
+| Location | Dataset | Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
+| :------: | :------: | :-------------------------------------------------------: | :-------: | :------: | :-------: | :-------: | :--------------------------------------------------------------------------------------------------: |
+| logits | CIFAR100 | [R18_BYOT](./byot_logits_resnet18_cifar100_8xb16_in1k.py) | 11.22 | 0.56 | 80.66 | 95.76 | [model & log](https://autolink.sensetime.com/pages/model/share/08ad706f-b3d4-4854-8019-e0b43607f001) |
+
+## Citation
+
+```latex
+@ARTICLE{2019arXiv190508094Z,
+    author = {{Zhang}, Linfeng and {Song}, Jiebo and {Gao}, Anni and {Chen}, Jingwei and {Bao}, Chenglong and {Ma}, Kaisheng},
+    title = {Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation},
+    journal = {arXiv e-prints},
+    keywords = {Computer Science - Machine Learning, Statistics - Machine Learning},
+    year = 2019,
+    month = may,
+    eid = {arXiv:1905.08094},
+    pages = {arXiv:1905.08094},
+    archivePrefix = {arXiv},
+    eprint = {1905.08094},
+    primaryClass = {cs.LG},
+    adsurl = {https://ui.adsabs.harvard.edu/abs/2019arXiv190508094Z},
+    adsnote = {Provided by the SAO/NASA Astrophysics Data System}
+}
+```
+
+## Get Started
+
+### Distillation training
+
+```bash
+sh tools/slurm_train.sh $PARTITION $JOB_NAME \
+  configs/distill/mmcls/byot/byot_logits_resnet18_cifar100_8xb16_in1k.py \
+  $WORK_DIR
+```
+
+### Test
+
+```bash
+sh tools/slurm_test.sh $PARTITION $JOB_NAME \
+  configs/distill/mmcls/byot/byot_logits_resnet18_cifar100_8xb16_in1k.py \
+  $WORK_DIR/latest.pth --eval $EVAL_SETTING
+```
diff --git a/configs/distill/mmcls/byot/byot_logits_resnet18_cifar100_8xb16_in1k.py b/configs/distill/mmcls/byot/byot_logits_resnet18_cifar100_8xb16_in1k.py
new file mode 100644
index 00000000..9f5fc8b7
--- /dev/null
+++ b/configs/distill/mmcls/byot/byot_logits_resnet18_cifar100_8xb16_in1k.py
@@ -0,0 +1,155 @@
+_base_ = [
+    '../../../_base_/datasets/mmcls/cifar100_bs16_auto_aug.py',
+    'mmcls::_base_/schedules/cifar10_bs128.py',
+    'mmcls::_base_/default_runtime.py'
+]
+
+optim_wrapper = dict(
+    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0005))
+param_scheduler = dict(
+    type='MultiStepLR', by_epoch=True, milestones=[80, 160, 240], gamma=0.1)
+train_cfg = dict(by_epoch=True, max_epochs=250, val_interval=1)
+
+model = dict(
+    _scope_='mmrazor',
+    type='SelfDistill',
+    data_preprocessor=dict(
+        type='ImgDataPreprocessor',
+        # RGB format normalization parameters
+        mean=[129.304, 124.070, 112.434],
+        std=[68.170, 65.392, 70.418],
+        # loaded images are already in RGB format
+        bgr_to_rgb=False),
+    architecture=dict(
+        type='mmcls.ImageClassifier',
+        backbone=dict(
+            type='mmcls.ResNet_CIFAR',
+            depth=18,
+            num_stages=4,
+            out_indices=(3, ),
+            style='pytorch'),
+        neck=dict(type='mmcls.GlobalAveragePooling'),
+        head=dict(
+            type='mmcls.LinearClsHead',
+            num_classes=100,
+            in_channels=512,
+            loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0))),
+    distiller=dict(
+        type='BYOTDistiller',
+        student_recorders=dict(
+            bb_s1=dict(type='ModuleOutputs', source='backbone.layer1.1.relu'),
+            bb_s2=dict(type='ModuleOutputs', source='backbone.layer2.1.relu'),
+            bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu')),
+        teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc'), + neck_gap=dict(type='ModuleOutputs', source='neck.gap'), + gt_labels=dict(type='ModuleInputs', source='head.loss_module')), + distill_losses=dict( + loss_fet_1=dict( + type='L2Loss', normalize=False, loss_weight=0.03, dist=True), + loss_label_1=dict(type='mmcls.CrossEntropyLoss', loss_weight=0.7), + loss_softl_1=dict(type='KLDivergence', tau=3, loss_weight=0.3), + loss_fet_2=dict( + type='L2Loss', normalize=False, loss_weight=0.03, dist=True), + loss_label_2=dict(type='mmcls.CrossEntropyLoss', loss_weight=0.7), + loss_softl_2=dict(type='KLDivergence', tau=3, loss_weight=0.3), + loss_fet_3=dict( + type='L2Loss', normalize=False, loss_weight=0., dist=True), + loss_label_3=dict(type='mmcls.CrossEntropyLoss', loss_weight=0.7), + loss_softl_3=dict(type='KLDivergence', tau=3, loss_weight=0.3)), + connectors=dict( + loss_s1_sfeat=dict( + type='BYOTConnector', + in_channel=64, + out_channel=512, + expansion=1, + kernel_size=3, + stride=2, + num_classes=100), + loss_s2_sfeat=dict( + type='BYOTConnector', + in_channel=128, + out_channel=512, + expansion=1, + kernel_size=3, + stride=2, + num_classes=100), + loss_s3_sfeat=dict( + type='BYOTConnector', + in_channel=256, + out_channel=512, + expansion=1, + kernel_size=3, + stride=2, + num_classes=100)), + loss_forward_mappings=dict( + loss_fet_1=dict( + s_feature=dict( + recorder='bb_s1', + from_student=True, + connector='loss_s1_sfeat', + connector_idx=0), + t_feature=dict(recorder='neck_gap', from_student=False)), + loss_label_1=dict( + cls_score=dict( + recorder='bb_s1', + from_student=True, + connector='loss_s1_sfeat', + connector_idx=1), + label=dict( + recorder='gt_labels', from_student=False, data_idx=1)), + loss_softl_1=dict( + preds_S=dict( + recorder='bb_s1', + from_student=True, + connector='loss_s1_sfeat', + connector_idx=1), + preds_T=dict(recorder='fc', from_student=False)), + loss_fet_2=dict( + s_feature=dict( + recorder='bb_s2', + from_student=True, + connector='loss_s2_sfeat', + connector_idx=0), + t_feature=dict(recorder='neck_gap', from_student=False)), + loss_label_2=dict( + cls_score=dict( + recorder='bb_s2', + from_student=True, + connector='loss_s2_sfeat', + connector_idx=1), + label=dict( + recorder='gt_labels', from_student=False, data_idx=1)), + loss_softl_2=dict( + preds_S=dict( + recorder='bb_s2', + from_student=True, + connector='loss_s2_sfeat', + connector_idx=1), + preds_T=dict(recorder='fc', from_student=False)), + loss_fet_3=dict( + s_feature=dict( + recorder='bb_s3', + from_student=True, + connector='loss_s3_sfeat', + connector_idx=0), + t_feature=dict(recorder='neck_gap', from_student=False)), + loss_label_3=dict( + cls_score=dict( + recorder='bb_s3', + from_student=True, + connector='loss_s3_sfeat', + connector_idx=1), + label=dict( + recorder='gt_labels', from_student=False, data_idx=1)), + loss_softl_3=dict( + preds_S=dict( + recorder='bb_s3', + from_student=True, + connector='loss_s3_sfeat', + connector_idx=1), + preds_T=dict(recorder='fc', from_student=False))))) + +find_unused_parameters = True + +val_cfg = dict(_delete_=True, type='mmrazor.SelfDistillValLoop') diff --git a/docs/en/imgs/model_zoo/byot/byot.png b/docs/en/imgs/model_zoo/byot/byot.png new file mode 100644 index 00000000..51416c3e Binary files /dev/null and b/docs/en/imgs/model_zoo/byot/byot.png differ diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index 2ba85f62..ce464dfd 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -3,12 +3,12 @@ from .hooks 
import DumpSubnetHook from .optimizers import SeparateOptimWrapperConstructor from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop, EvolutionSearchLoop, - GreedySamplerTrainLoop, SingleTeacherDistillValLoop, - SlimmableValLoop) + GreedySamplerTrainLoop, SelfDistillValLoop, + SingleTeacherDistillValLoop, SlimmableValLoop) __all__ = [ 'SeparateOptimWrapperConstructor', 'DumpSubnetHook', 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', - 'GreedySamplerTrainLoop', 'AutoSlimValLoop' + 'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop' ] diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py index 75e3ecc7..9715a4e6 100644 --- a/mmrazor/engine/runner/__init__.py +++ b/mmrazor/engine/runner/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .autoslim_val_loop import AutoSlimValLoop from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop -from .distill_val_loop import SingleTeacherDistillValLoop +from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop from .evolution_search_loop import EvolutionSearchLoop from .slimmable_val_loop import SlimmableValLoop from .subnet_sampler_loop import GreedySamplerTrainLoop @@ -9,5 +9,5 @@ from .subnet_sampler_loop import GreedySamplerTrainLoop __all__ = [ 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', - 'GreedySamplerTrainLoop', 'AutoSlimValLoop' + 'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop' ] diff --git a/mmrazor/engine/runner/distill_val_loop.py b/mmrazor/engine/runner/distill_val_loop.py index 8bb7dbc0..26314443 100644 --- a/mmrazor/engine/runner/distill_val_loop.py +++ b/mmrazor/engine/runner/distill_val_loop.py @@ -100,3 +100,42 @@ class SingleTeacherDistillValLoop(ValLoop): batch_idx=idx, data_batch=data_batch, outputs=outputs) + + +@LOOPS.register_module() +class SelfDistillValLoop(ValLoop): + """Knowledge Distill loop for validation. Only validate student. + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict): A dataloader object or a dict to + build a dataloader. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 validation. Defaults to + False. + """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + fp16: bool = False) -> None: + super().__init__(runner, dataloader, evaluator, fp16) + + def run(self): + """Launch validation.""" + self.runner.call_hook('before_val') + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch) + # compute student metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + student_metrics = dict() + for key, value in metrics.items(): + student_key = 'student.' + key + student_metrics[student_key] = value + + self.runner.call_hook('after_val_epoch', metrics=student_metrics) + self.runner.call_hook('after_val') diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index f0e4c541..20c6bd85 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .base import BaseAlgorithm -from .distill import FpnTeacherDistill, SingleTeacherDistill +from .distill import FpnTeacherDistill, SelfDistill, SingleTeacherDistill from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP from .pruning import SlimmableNetwork, SlimmableNetworkDDP __all__ = [ 'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS', 'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP', - 'Darts', 'DartsDDP' + 'Darts', 'DartsDDP', 'SelfDistill' ] diff --git a/mmrazor/models/algorithms/distill/__init__.py b/mmrazor/models/algorithms/distill/__init__.py index 4d348961..4d8473c7 100644 --- a/mmrazor/models/algorithms/distill/__init__.py +++ b/mmrazor/models/algorithms/distill/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .configurable import FpnTeacherDistill, SingleTeacherDistill +from .configurable import FpnTeacherDistill, SelfDistill, SingleTeacherDistill -__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill'] +__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill'] diff --git a/mmrazor/models/algorithms/distill/configurable/__init__.py b/mmrazor/models/algorithms/distill/configurable/__init__.py index efdafcdf..4638ed09 100644 --- a/mmrazor/models/algorithms/distill/configurable/__init__.py +++ b/mmrazor/models/algorithms/distill/configurable/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .fpn_teacher_distill import FpnTeacherDistill +from .self_distill import SelfDistill from .single_teacher_distill import SingleTeacherDistill -__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill'] +__all__ = ['SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill'] diff --git a/mmrazor/models/algorithms/distill/configurable/self_distill.py b/mmrazor/models/algorithms/distill/configurable/self_distill.py new file mode 100644 index 00000000..579df447 --- /dev/null +++ b/mmrazor/models/algorithms/distill/configurable/self_distill.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +from mmengine import BaseDataElement +from torch import nn + +from mmrazor.models.utils import add_prefix +from mmrazor.registry import MODELS +from ...base import BaseAlgorithm, LossResults + + +@MODELS.register_module() +class SelfDistill(BaseAlgorithm): + """``SelfDistill`` can be used to develop distill algorithms without + teacher. + + Args: + distiller (dict): The config dict for built distiller. Distiller may + have teacher. + student_trainable (bool): Whether the student is trainable. Defaults + to True. + calculate_student_loss (bool): Whether to calculate student loss + (original task loss) to update student model. Defaults to True. + """ + + def __init__(self, + distiller: dict, + student_trainable: bool = True, + calculate_student_loss: bool = True, + **kwargs) -> None: + super().__init__(**kwargs) + + self.distiller = MODELS.build(distiller) + # The student model will not calculate gradients and update parameters + # in some pretraining process. + self.student_trainable = student_trainable + + # The student loss will not be updated into ``losses`` in some + # pretraining process. + self.calculate_student_loss = calculate_student_loss + + # In ``ConfigurableDistller``, the recorder manager is just + # constructed, but not really initialized yet. + self.distiller.prepare_from_student(self.student) + # Still prepare from self-teacher. 
Teacher recorders of + # ``SelfDistiller`` hook from self.student but require detach(). + self.distiller.prepare_from_teacher(self.student) + + @property + def student(self) -> nn.Module: + """Alias for ``architecture``.""" + return self.architecture + + def loss( + self, + batch_inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + ) -> LossResults: + """Calculate losses from a batch of inputs and data samples.""" + + losses = dict() + + # If the `override_data` of a delivery is True, the delivery will + # override the origin data with the recorded data. + self.distiller.set_deliveries_override(True) + # Original task loss will not be used during some pretraining process. + if self.calculate_student_loss: + # teacher_recorders hook from student + with self.distiller.student_recorders, \ + self.distiller.teacher_recorders, \ + self.distiller.deliveries: + student_losses = self.student( + batch_inputs, data_samples, mode='loss') + losses.update(add_prefix(student_losses, 'student')) + else: + with self.distiller.student_recorders, \ + self.distiller.teacher_recorders, \ + self.distiller.deliveries: + if self.student_trainable: + _ = self.student(batch_inputs, data_samples, mode='loss') + else: + with torch.no_grad(): + _ = self.student( + batch_inputs, data_samples, mode='loss') + + # Automatically compute distill losses based on `loss_forward_mappings` + # The required data already exists in the recorders. + distill_losses = self.distiller.compute_distill_losses() + losses.update(add_prefix(distill_losses, 'distill')) + + return losses diff --git a/mmrazor/models/architectures/connectors/__init__.py b/mmrazor/models/architectures/connectors/__init__.py index c78bfe99..3823cfe6 100644 --- a/mmrazor/models/architectures/connectors/__init__.py +++ b/mmrazor/models/architectures/connectors/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .byot_connector import BYOTConnector from .convmodule_connector import ConvModuleConncetor -__all__ = ['ConvModuleConncetor'] +__all__ = ['ConvModuleConncetor', 'BYOTConnector'] diff --git a/mmrazor/models/architectures/connectors/byot_connector.py b/mmrazor/models/architectures/connectors/byot_connector.py new file mode 100644 index 00000000..657e2d4a --- /dev/null +++ b/mmrazor/models/architectures/connectors/byot_connector.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from math import log +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from mmrazor.registry import MODELS +from ...ops.darts_series import DartsSepConv +from .base_connector import BaseConnector + + +@MODELS.register_module() +class BYOTConnector(BaseConnector): + """BYOTConnector connector that adds a self-attention with DartsSepConv. + + Args: + in_channel (int): The input channel of the DartsSepConv. + Use like input_tensor_channel = in_channel * expansion. + out_channel (int): The output channel of the DartsSepConv. + Use like output_tensor_channel = out_channel * expansion. + num_classes (int): The classification class num. + expansion (int): Expansion of DartsSepConv. Default to 4. + pool_size (int | tuple[int]): Average 2D pool size. Default to 4. + kernel_size (int | tuple[int]): Size of the convolving kernel in + DartsSepConv. Same as that in ``nn._ConvNd``. Default to 3. + stride (int | tuple[int]): Stride of the first layer in DartsSepConv. + Same as that in ``nn._ConvNd``. Default to 1. + init_cfg (dict, optional): The config to control the initialization. 
+ """ + + def __init__( + self, + in_channel: int, + out_channel: int, + num_classes: int, + expansion: int = 4, + pool_size: Union[int, Tuple[int]] = 4, + kernel_size: Union[int, Tuple[int]] = 3, + stride: Union[int, Tuple[int]] = 1, + init_cfg: Optional[Dict] = None, + ) -> None: + super().__init__(init_cfg) + self.attention = nn.Sequential( + DartsSepConv( + in_channels=in_channel * expansion, + out_channels=in_channel * expansion, + kernel_size=kernel_size, + stride=stride), nn.BatchNorm2d(in_channel * expansion), + nn.ReLU(), nn.Upsample(scale_factor=2, mode='bilinear'), + nn.Sigmoid()) + scala_num = log(out_channel / in_channel, 2) + assert scala_num.is_integer() + scala = [] + + _in_channel = in_channel + + for _ in range(int(scala_num)): + scala.append( + DartsSepConv( + in_channels=_in_channel * expansion, + out_channels=_in_channel * 2 * expansion, + kernel_size=kernel_size, + stride=stride)) + _in_channel *= 2 + scala.append(nn.AvgPool2d(pool_size)) + self.scala = nn.Sequential(*scala) + self.fc = nn.Linear(out_channel * expansion, num_classes) + + def forward_train(self, feature: torch.Tensor) -> torch.Tensor: + """Forward computation. + + Args: + feature (torch.Tensor): Input feature. + """ + feat = self.attention(feature) + feat = feat * feature + + feat = self.scala(feat) + feat = feat.view(feature.size(0), -1) + logits = self.fc(feat) + return (feat, logits) diff --git a/mmrazor/models/distillers/__init__.py b/mmrazor/models/distillers/__init__.py index b3aee9f9..90584682 100644 --- a/mmrazor/models/distillers/__init__.py +++ b/mmrazor/models/distillers/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base_distiller import BaseDistiller +from .byot_distiller import BYOTDistiller from .configurable_distiller import ConfigurableDistiller -__all__ = ['ConfigurableDistiller', 'BaseDistiller'] +__all__ = ['ConfigurableDistiller', 'BaseDistiller', 'BYOTDistiller'] diff --git a/mmrazor/models/distillers/byot_distiller.py b/mmrazor/models/distillers/byot_distiller.py new file mode 100644 index 00000000..fca0d774 --- /dev/null +++ b/mmrazor/models/distillers/byot_distiller.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +from mmrazor.registry import MODELS +from .configurable_distiller import ConfigurableDistiller + + +@MODELS.register_module() +class BYOTDistiller(ConfigurableDistiller): + """``BYOTDistiller`` inherits ``ConfigurableDistiller`` and only modifies + ``get_record()`` function to ``get_record_with_cidx()``. + + In ``BYOTDistiller``, ``self.teacher_recorder`` records self-teacher data + which requires detach(). + """ + + def get_record(self, + recorder: str, + from_student: bool, + record_idx: int = 0, + data_idx: Optional[int] = None, + connector: Optional[str] = None, + connector_idx: Optional[int] = None) -> List: + """According to each item in ``record_infos``, get the corresponding + record in ``recorder_manager``. + + Detach teacher_record. + """ + + if from_student: + recorder_ = self.student_recorders.get_recorder(recorder) + else: + recorder_ = self.teacher_recorders.get_recorder(recorder) + record_data = recorder_.get_record_data(record_idx, data_idx) + + if connector: + record_data = self.connectors[connector](record_data) + if connector_idx is not None: + record_data = record_data[connector_idx] + # Detach self-teacher output Tensor from model, assert hook tensor. 
+ if not from_student: + record_data = record_data.detach() + + return record_data diff --git a/mmrazor/models/distillers/configurable_distiller.py b/mmrazor/models/distillers/configurable_distiller.py index a35e83e5..5899cd67 100644 --- a/mmrazor/models/distillers/configurable_distiller.py +++ b/mmrazor/models/distillers/configurable_distiller.py @@ -180,7 +180,8 @@ class ConfigurableDistiller(BaseDistiller): from_student: bool, record_idx: int = 0, data_idx: Optional[int] = None, - connector: Optional[str] = None) -> List: + connector: Optional[str] = None, + connector_idx: Optional[int] = None) -> List: """According to each item in ``record_infos``, get the corresponding record in ``recorder_manager``.""" @@ -192,6 +193,8 @@ class ConfigurableDistiller(BaseDistiller): if connector: record_data = self.connectors[connector](record_data) + if connector_idx is not None: + record_data = record_data[connector_idx] return record_data @@ -235,9 +238,10 @@ class ConfigurableDistiller(BaseDistiller): f'instance, but got {type(forward_mappings)}') loss_module = losses[loss_name] - loss_forward_keys = signature( - loss_module.forward).parameters.keys() - assert len(loss_forward_keys) == len(forward_mappings.keys()) + loss_forward_params = signature(loss_module.forward).parameters + loss_forward_keys = loss_forward_params.keys() + # Allow default params. + # Check non-default params, not len(params). for forward_key, record_info in forward_mappings.items(): assert forward_key in loss_forward_keys, \ @@ -245,6 +249,11 @@ class ConfigurableDistiller(BaseDistiller): {type(loss_module).__name__} forward, \ please check your config.' + if (loss_forward_params[forward_key].default == + loss_forward_params[forward_key].empty): + # default params without check + continue + assert 'recorder' in record_info, \ 'Each item of loss_forward_mappings should have ' \ '"recorder", pls check your config.' diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index 452d1cba..a21a9513 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -2,6 +2,7 @@ from .ab_loss import ABLoss from .cwd import ChannelWiseDivergence from .decoupled_kd import DKDLoss +from .kd_soft_ce_loss import KDSoftCELoss from .kl_divergence import KLDivergence from .l2_loss import L2Loss from .relational_kd import AngleWiseRKD, DistanceWiseRKD @@ -9,5 +10,5 @@ from .weighted_soft_label_distillation import WSLD __all__ = [ 'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD', - 'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss' + 'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss' ] diff --git a/mmrazor/models/losses/kd_soft_ce_loss.py b/mmrazor/models/losses/kd_soft_ce_loss.py new file mode 100644 index 00000000..43a5f288 --- /dev/null +++ b/mmrazor/models/losses/kd_soft_ce_loss.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcls.models.losses.cross_entropy_loss import soft_cross_entropy + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class KDSoftCELoss(nn.Module): + """Distilling the Knowledge in a Neural Network, NIPS2014. Based on Soft + Cross Entropy criterion. + + https://arxiv.org/pdf/1503.02531.pdf + + + Args: + tau (int, optional): Temperature. Defaults to 1. + reduction (str): Specifies the reduction to apply to the loss: + ``'none'`` | ``'none'`` | ``'sum'`` | ``'mean'``. 
+            ``'none'``: no reduction will be applied,
+            ``'sum'``: the output will be summed,
+            ``'mean'``: the output will be divided by the number of
+                elements in the output.
+            Default: ``'mean'``
+        mult_tem_square (bool, optional): Multiply square of temperature
+            or not. Defaults to True.
+        loss_weight (float): Weight of loss. Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        tau: float = 1.0,
+        reduction: str = 'mean',
+        mult_tem_square: bool = True,
+        loss_weight: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.tau = tau
+        self.mult_tem_square = mult_tem_square
+        self.loss_weight = loss_weight
+        self.cls_criterion = soft_cross_entropy
+
+        accept_reduction = {None, 'none', 'mean', 'sum'}
+        assert reduction in accept_reduction, \
+            f'KDSoftCELoss supports reduction {accept_reduction}, ' \
+            f'but gets {reduction}.'
+        self.reduction = reduction
+
+    def forward(
+        self,
+        preds_S: torch.Tensor,
+        preds_T: torch.Tensor,
+        weight: torch.Tensor = None,
+        avg_factor: int = None,
+        reduction_override: str = None,
+    ) -> torch.Tensor:
+        """Forward computation.
+
+        Args:
+            preds_S (torch.Tensor): The student model prediction with
+                shape (N, C).
+            preds_T (torch.Tensor): The teacher model prediction with
+                shape (N, C).
+            weight (torch.Tensor, optional): Sample-wise loss weight with
+                shape (N, C). Defaults to None.
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+            reduction_override (str, optional): Override reduction in forward.
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: The calculated loss value.
+        """
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+
+        preds_S = preds_S / self.tau
+        soft_label = F.softmax((preds_T / self.tau), dim=-1)
+        loss_cls = self.loss_weight * self.cls_criterion(
+            preds_S,
+            soft_label,
+            weight,
+            reduction=reduction,
+            avg_factor=avg_factor)
+        if self.mult_tem_square:
+            loss_cls *= (self.tau**2)
+        return loss_cls
diff --git a/mmrazor/models/losses/l2_loss.py b/mmrazor/models/losses/l2_loss.py
index 8b373ed3..55313127 100644
--- a/mmrazor/models/losses/l2_loss.py
+++ b/mmrazor/models/losses/l2_loss.py
@@ -15,6 +15,8 @@ class L2Loss(nn.Module):
         mult (float): Multiplier for feature normalization. Defaults to 1.0.
         div_element (bool): Whether to divide the loss by element-wise.
             Defaults to False.
+        dist (bool): Whether to compute the loss as the two-norm distance,
+            as ``torch.dist(p=2)`` does. Defaults to False.
     """
 
     def __init__(
@@ -23,12 +25,14 @@ class L2Loss(nn.Module):
         normalize: bool = True,
         mult: float = 1.0,
         div_element: bool = False,
+        dist: bool = False,
     ) -> None:
         super().__init__()
         self.loss_weight = loss_weight
         self.normalize = normalize
         self.mult = mult
         self.div_element = div_element
+        self.dist = dist
 
     def forward(
         self,
@@ -49,10 +53,14 @@ class L2Loss(nn.Module):
 
         loss = torch.sum(torch.pow(torch.sub(s_feature, t_feature), 2))
 
-        if self.div_element:
-            loss = loss / s_feature.numel()
+        # Compute the l2 loss as a two-norm distance.
+        if self.dist:
+            loss = torch.sqrt(loss)
         else:
-            loss = loss / s_feature.size(0)
+            if self.div_element:
+                loss = loss / s_feature.numel()
+            else:
+                loss = loss / s_feature.size(0)
 
         return self.loss_weight * loss
 
diff --git a/tests/test_models/test_algorithms/test_self_distill.py b/tests/test_models/test_algorithms/test_self_distill.py
new file mode 100644
index 00000000..4a8cb223
--- /dev/null
+++ b/tests/test_models/test_algorithms/test_self_distill.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase + +import torch +from mmcv import ConfigDict + +from mmrazor.models import SelfDistill + + +class TestSelfDistill(TestCase): + + def test_init(self): + + student_recorders_cfg = ConfigDict( + conv=dict(type='ModuleOutputs', source='conv')) + teacher_recorders_cfg = ConfigDict( + conv=dict(type='ModuleOutputs', source='conv')) + + alg_kwargs = ConfigDict( + architecture=dict(type='ToyStudent'), + distiller=dict( + type='BYOTDistiller', + student_recorders=student_recorders_cfg, + teacher_recorders=teacher_recorders_cfg, + distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')), + loss_forward_mappings=dict( + loss_toy=dict( + arg1=dict(from_student=True, recorder='conv'), + arg2=dict(from_student=False, recorder='conv'))))) + + _ = SelfDistill(**alg_kwargs) + + def test_loss(self): + + student_recorders_cfg = ConfigDict( + conv=dict(type='ModuleOutputs', source='conv')) + teacher_recorders_cfg = ConfigDict( + conv=dict(type='ModuleOutputs', source='conv')) + + alg_kwargs = ConfigDict( + architecture=dict(type='ToyStudent'), + distiller=dict( + type='BYOTDistiller', + student_recorders=student_recorders_cfg, + teacher_recorders=teacher_recorders_cfg, + distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')), + loss_forward_mappings=dict( + loss_toy=dict( + arg1=dict(from_student=True, recorder='conv'), + arg2=dict(from_student=False, recorder='conv'))))) + + img = torch.randn(1, 3, 1, 1) + + alg = SelfDistill(**alg_kwargs) + losses = alg(img, mode='loss') + self.assertIn('distill.loss_toy', losses) + self.assertIn('student.loss', losses) diff --git a/tests/test_models/test_architectures/test_connectors/test_connectors.py b/tests/test_models/test_architectures/test_connectors/test_connectors.py index b9c6f608..1937f588 100644 --- a/tests/test_models/test_architectures/test_connectors/test_connectors.py +++ b/tests/test_models/test_architectures/test_connectors/test_connectors.py @@ -3,7 +3,7 @@ from unittest import TestCase import torch -from mmrazor.models import ConvModuleConncetor +from mmrazor.models import BYOTConnector, ConvModuleConncetor class TestConnector(TestCase): @@ -36,3 +36,23 @@ class TestConnector(TestCase): convmodule_connector_cfg['conv_cfg'] = 'conv2d' with self.assertRaises(AssertionError): _ = ConvModuleConncetor(**convmodule_connector_cfg) + + def test_byot_connector(self): + byot_connector_cfg = dict( + in_channel=16, + out_channel=32, + num_classes=10, + expansion=4, + pool_size=4, + kernel_size=3, + stride=2, + init_cfg=None) + byot_connector = BYOTConnector(**byot_connector_cfg) + + s_feat = torch.randn(1, 16 * 4, 8, 8) + t_feat = torch.randn(1, 32 * 4) + labels = torch.randn(1, 10) + + output, logits = byot_connector.forward_train(s_feat) + assert output.size() == t_feat.size() + assert logits.size() == labels.size() diff --git a/tests/test_models/test_distillers/test_byot_distill.py b/tests/test_models/test_distillers/test_byot_distill.py new file mode 100644 index 00000000..004eb620 --- /dev/null +++ b/tests/test_models/test_distillers/test_byot_distill.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy +from unittest import TestCase + +from mmcv import ConfigDict + +from mmrazor.models import BYOTDistiller + + +class TestBYOTDistiller(TestCase): + + def test_init(self): + + student_recorders_cfg = ConfigDict( + conv=dict(type='ModuleOutputs', source='conv')) + teacher_recorders_cfg = ConfigDict( + conv=dict(type='ModuleOutputs', source='conv')) + + distiller_kwargs = ConfigDict( + student_recorders=student_recorders_cfg, + teacher_recorders=teacher_recorders_cfg, + distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')), + loss_forward_mappings=dict( + loss_toy=dict( + arg1=dict(from_student=True, recorder='conv'), + arg2=dict(from_student=False, recorder='conv'), + )), + ) + + _ = BYOTDistiller(**distiller_kwargs) + + distiller_kwargs_ = copy.deepcopy(distiller_kwargs) + distiller_kwargs_['distill_losses'] = None + with self.assertRaisesRegex(AssertionError, + '"loss_toy" is not in distill'): + _ = BYOTDistiller(**distiller_kwargs_) + + distiller_kwargs_ = copy.deepcopy(distiller_kwargs) + distiller_kwargs_['distill_losses'] = dict( + toy=dict(type='ToyDistillLoss')) + distiller_kwargs_['loss_forward_mappings'] = dict( + toy=dict( + arg1=dict(from_student=True, recorder='conv'), + arg2=dict(from_student=False, recorder='conv'))) + with self.assertWarnsRegex(UserWarning, 'Warning: If toy is a'): + _ = BYOTDistiller(**distiller_kwargs_) + + distiller_kwargs_ = copy.deepcopy(distiller_kwargs) + distiller_kwargs_['loss_forward_mappings'] = None + _ = BYOTDistiller(**distiller_kwargs_) + + distiller_kwargs_ = copy.deepcopy(distiller_kwargs) + distiller_kwargs_['loss_forward_mappings'] = list('AAA') + + with self.assertRaisesRegex(TypeError, + 'loss_forward_mappings should be '): + _ = BYOTDistiller(**distiller_kwargs_) + + distiller_kwargs_ = copy.deepcopy(distiller_kwargs) + distiller_kwargs_['loss_forward_mappings']['loss_toy'] = list() + with self.assertRaisesRegex( + TypeError, 'Each item of loss_forward_mappings should be '): + _ = BYOTDistiller(**distiller_kwargs_) + + distiller_kwargs_ = copy.deepcopy(distiller_kwargs) + distiller_kwargs_.loss_forward_mappings.loss_toy.arg1.from_student = '' + with self.assertRaisesRegex(TypeError, + 'from_student should be a bool'): + _ = BYOTDistiller(**distiller_kwargs_) diff --git a/tests/test_models/test_losses/test_distillation_losses.py b/tests/test_models/test_losses/test_distillation_losses.py index 6444872f..c1d6030b 100644 --- a/tests/test_models/test_losses/test_distillation_losses.py +++ b/tests/test_models/test_losses/test_distillation_losses.py @@ -3,7 +3,7 @@ from unittest import TestCase import torch -from mmrazor.models import ABLoss, DKDLoss +from mmrazor.models import ABLoss, DKDLoss, KDSoftCELoss class TestLosses(TestCase): @@ -50,3 +50,9 @@ class TestLosses(TestCase): dkd_loss = DKDLoss(**dkd_loss_cfg) # dkd requires label logits self.normal_test_1d(dkd_loss, labels=True) + + def test_kdSoftce_loss(self): + kdSoftce_loss_cfg = dict(loss_weight=1.0) + kdSoftce_loss = KDSoftCELoss(**kdSoftce_loss_cfg) + # kd soft ce loss requires label logits + self.normal_test_1d(kdSoftce_loss, labels=True) diff --git a/tests/test_runners/test_distill_val_loop.py b/tests/test_runners/test_distill_val_loop.py index af5fe9be..8bf28516 100644 --- a/tests/test_runners/test_distill_val_loop.py +++ b/tests/test_runners/test_distill_val_loop.py @@ -13,7 +13,8 @@ from mmengine.model import BaseModel from mmengine.runner import Runner from torch.utils.data import Dataset -from mmrazor.engine import SingleTeacherDistillValLoop # noqa: F401 
+from mmrazor.engine import SelfDistillValLoop # noqa: F401 +from mmrazor.engine import SingleTeacherDistillValLoop from mmrazor.registry import DATASETS, METRICS, MODELS @@ -125,3 +126,54 @@ class TestSingleTeacherDistillValLoop(TestCase): runner.val() self.assertIn('val/teacher.acc', runner.message_hub.log_scalars.keys()) + + +class TestSelfDistillValLoop(TestCase): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + + val_dataloader = dict( + dataset=dict(type='ToyDataset_DistillValLoop'), + sampler=dict(type='DefaultSampler', shuffle=False), + batch_size=3, + num_workers=0) + val_evaluator = dict(type='ToyMetric_DistillValLoop') + + val_loop_cfg = dict( + default_scope='mmrazor', + model=dict(type='ToyModel_DistillValLoop'), + work_dir=self.temp_dir, + val_dataloader=val_dataloader, + val_evaluator=val_evaluator, + val_cfg=dict(type='SelfDistillValLoop'), + custom_hooks=[], + default_hooks=dict( + runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', interval=1, by_epoch=True), + sampler_seed=dict(type='DistSamplerSeedHook')), + launcher='none', + env_cfg=dict(dist_cfg=dict(backend='nccl')), + ) + self.val_loop_cfg = Config(val_loop_cfg) + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.val_loop_cfg) + cfg.experiment_name = 'test_init_self' + runner = Runner.from_cfg(cfg) + loop = runner.build_val_loop(cfg.val_cfg) + + self.assertIsInstance(loop, SelfDistillValLoop) + + def test_run(self): + cfg = copy.deepcopy(self.val_loop_cfg) + cfg.experiment_name = 'test_run_self' + runner = Runner.from_cfg(cfg) + runner.val()
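For reference, the distillation term that the new ``KDSoftCELoss`` implements is a temperature-scaled soft cross-entropy, rescaled by ``tau ** 2`` when ``mult_tem_square=True``. The snippet below is only an illustrative sketch in plain PyTorch with placeholder names; it is not part of this patch and omits ``loss_weight``, sample-wise ``weight`` and ``avg_factor``, which the module above additionally supports.

```python
import torch
import torch.nn.functional as F


def kd_soft_ce(preds_S: torch.Tensor,
               preds_T: torch.Tensor,
               tau: float = 3.0) -> torch.Tensor:
    """Soft cross-entropy between temperature-scaled student and teacher
    logits, averaged over the batch and multiplied by tau ** 2 so the
    gradient magnitude stays comparable across temperatures."""
    soft_label = F.softmax(preds_T / tau, dim=-1)
    log_prob_s = F.log_softmax(preds_S / tau, dim=-1)
    loss = -(soft_label * log_prob_s).sum(dim=-1).mean()
    return loss * tau**2


# Toy usage: a batch of 4 samples with 100 classes, as in the CIFAR-100 config.
student_logits = torch.randn(4, 100)
teacher_logits = torch.randn(4, 100)
print(kd_soft_ce(student_logits, teacher_logits))
```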